1package goose
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "path/filepath"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/pressly/goose/v3/internal/sqlparser"
14)
15
16// NewGoMigration creates a new Go migration.
17//
18// Both up and down functions may be nil, in which case the migration will be recorded in the
19// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
20// a version without running any functions. See [GoFunc] for more details.
21func NewGoMigration(version int64, up, down *GoFunc) *Migration {
22 m := &Migration{
23 Type: TypeGo,
24 Registered: true,
25 Version: version,
26 Next: -1, Previous: -1,
27 goUp: &GoFunc{Mode: TransactionEnabled},
28 goDown: &GoFunc{Mode: TransactionEnabled},
29 construct: true,
30 }
31 updateMode := func(f *GoFunc) *GoFunc {
32 // infer mode from function
33 if f.Mode == 0 {
34 if f.RunTx != nil && f.RunDB == nil {
35 f.Mode = TransactionEnabled
36 }
37 if f.RunTx == nil && f.RunDB != nil {
38 f.Mode = TransactionDisabled
39 }
40 // Always default to TransactionEnabled if both functions are nil. This is the most
41 // common use case.
42 if f.RunDB == nil && f.RunTx == nil {
43 f.Mode = TransactionEnabled
44 }
45 }
46 return f
47 }
48 // To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
49 // we will remove these fields in favor of [GoFunc].
50 //
51 // Note, this function does not do any validation. Validation is lazily done when the migration
52 // is registered.
53 if up != nil {
54 m.goUp = updateMode(up)
55
56 if up.RunDB != nil {
57 m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
58 m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
59 }
60 if up.RunTx != nil {
61 m.UseTx = true
62 m.UpFnContext = up.RunTx // func(context.Context, *sql.Tx) error
63 m.UpFn = withoutContext(up.RunTx) // func(*sql.Tx) error
64 }
65 }
66 if down != nil {
67 m.goDown = updateMode(down)
68
69 if down.RunDB != nil {
70 m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
71 m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
72 }
73 if down.RunTx != nil {
74 m.UseTx = true
75 m.DownFnContext = down.RunTx // func(context.Context, *sql.Tx) error
76 m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
77 }
78 }
79 return m
80}
81
// Migration struct represents either a SQL or Go migration.
//
// Avoid constructing migrations manually, use [NewGoMigration] function.
type Migration struct {
	// Type is the kind of migration (NewGoMigration sets it to TypeGo).
	Type MigrationType
	// Version is the migration's version number.
	Version int64
	// Source is the path to the .sql script or .go file. It may be empty for Go migrations that
	// have been registered globally and don't have a source file.
	Source string

	// Context-aware Go migration functions. When UseTx is true, run invokes the *Context
	// variants inside a transaction. Otherwise it invokes the *NoTxContext variants directly
	// against the *sql.DB.
	UpFnContext, DownFnContext         GoMigrationContext
	UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext

	// These fields will be removed in a future major version. They are here for backwards
	// compatibility and are an implementation detail.
	//
	// Registered must be true for run to execute a .go migration. UseTx selects the
	// transactional vs. non-transactional Go functions above.
	Registered bool
	UseTx      bool
	Next       int64 // next version, or -1 if none
	Previous   int64 // previous version, -1 if none

	// We still save the non-context versions in the struct in case someone is using them. Goose
	// does not use these internally anymore in favor of the context-aware versions. These fields
	// will be removed in a future major version.

	UpFn       GoMigration     // Deprecated: use UpFnContext instead.
	DownFn     GoMigration     // Deprecated: use DownFnContext instead.
	UpFnNoTx   GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead.
	DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.

	// noVersioning, when true, makes run skip recording the Go migration's version in the
	// versions table. It is also forwarded to runSQLMigration for SQL migrations.
	noVersioning bool

	// These fields are used internally by goose and users are not expected to set them. Instead,
	// use [NewGoMigration] to create a new go migration.
	//
	// construct is set to true by NewGoMigration. goUp/goDown hold the GoFunc values passed to
	// it, or TransactionEnabled placeholders when nil was passed.
	construct    bool
	goUp, goDown *GoFunc

	// sql holds the lazily parsed state of a SQL migration.
	sql sqlMigration
}
120
// sqlMigration holds the lazily parsed contents of a SQL migration file.
type sqlMigration struct {
	// The Parsed field is used to track whether the SQL migration has been parsed. It serves as an
	// optimization to avoid parsing migrations that may never be needed. Typically, migrations are
	// incremental, and users often run only the most recent ones, making parsing of prior
	// migrations unnecessary in most cases.
	Parsed bool

	// Parsed must be set to true before the following fields are used.
	//
	// UseTx reports whether the statements should be run inside a transaction.
	// NOTE(review): presumably populated from the useTx result of
	// sqlparser.ParseSQLMigration; confirm against the parsing code.
	UseTx bool
	// Up and Down hold the SQL statements for the up and down directions respectively.
	Up   []string
	Down []string
}
133
// GoFunc represents a Go migration function.
type GoFunc struct {
	// Exactly one of these must be set, or both must be nil.
	//
	// RunTx receives a *sql.Tx, i.e. it runs inside a transaction.
	RunTx func(ctx context.Context, tx *sql.Tx) error
	// -- OR --
	// RunDB receives the *sql.DB directly, i.e. it runs outside a transaction.
	RunDB func(ctx context.Context, db *sql.DB) error

	// Mode is the transaction mode for the migration. Users do not need to set this field when
	// supplying a run function.
	//
	// When Mode is left as its zero value, NewGoMigration infers it as follows:
	//   - TransactionEnabled if only RunTx is set.
	//   - TransactionDisabled if only RunDB is set.
	//   - TransactionEnabled if both run functions are nil.
	//   - If both run functions are set, Mode stays unset; validation happens lazily at
	//     registration.
	//
	// A non-zero Mode is kept exactly as given. It is not overwritten by inference.
	//
	// The use case for nil functions is to record a version in the version table without
	// invoking a Go migration function.
	//
	// The only time this field is required is if BOTH run functions are nil AND you want to
	// override the default transaction mode.
	Mode TransactionMode
}
153
// TransactionMode represents the possible transaction modes for a migration.
//
// The zero value is not a named mode. NewGoMigration treats it as "not set" and infers a mode
// from the supplied run functions.
type TransactionMode int

const (
	// TransactionEnabled runs the migration inside a database transaction.
	TransactionEnabled TransactionMode = iota + 1
	// TransactionDisabled runs the migration directly against the *sql.DB.
	TransactionDisabled
)

// String returns the snake_case name of the mode. Values outside the defined set produce a
// diagnostic string that includes the raw integer.
func (m TransactionMode) String() string {
	if m == TransactionEnabled {
		return "transaction_enabled"
	}
	if m == TransactionDisabled {
		return "transaction_disabled"
	}
	return fmt.Sprintf("unknown transaction mode (%d)", m)
}
172
// MigrationRecord struct.
//
// NOTE(review): judging only by the field names, this appears to describe one entry of
// migration history (version, timestamp, applied flag); confirm before relying on it.
//
// Deprecated: unused and will be removed in a future major version.
type MigrationRecord struct {
	VersionID int64     // presumably the migration version
	TStamp    time.Time // presumably when the record was written
	IsApplied bool      // was this a result of up() or down()
}
181
182func (m *Migration) String() string {
183 return fmt.Sprint(m.Source)
184}
185
186// Up runs an up migration.
187func (m *Migration) Up(db *sql.DB) error {
188 ctx := context.Background()
189 return m.UpContext(ctx, db)
190}
191
192// UpContext runs an up migration.
193func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
194 if err := m.run(ctx, db, true); err != nil {
195 return err
196 }
197 return nil
198}
199
200// Down runs a down migration.
201func (m *Migration) Down(db *sql.DB) error {
202 ctx := context.Background()
203 return m.DownContext(ctx, db)
204}
205
206// DownContext runs a down migration.
207func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
208 if err := m.run(ctx, db, false); err != nil {
209 return err
210 }
211 return nil
212}
213
214func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
215 switch filepath.Ext(m.Source) {
216 case ".sql":
217 f, err := baseFS.Open(m.Source)
218 if err != nil {
219 return fmt.Errorf("ERROR %v: failed to open SQL migration file: %w", filepath.Base(m.Source), err)
220 }
221 defer f.Close()
222
223 statements, useTx, err := sqlparser.ParseSQLMigration(f, sqlparser.FromBool(direction), verbose)
224 if err != nil {
225 return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err)
226 }
227
228 start := time.Now()
229 if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
230 return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err)
231 }
232 finish := truncateDuration(time.Since(start))
233
234 if len(statements) > 0 {
235 log.Printf("OK %s (%s)", filepath.Base(m.Source), finish)
236 } else {
237 log.Printf("EMPTY %s (%s)", filepath.Base(m.Source), finish)
238 }
239
240 case ".go":
241 if !m.Registered {
242 return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
243 }
244 start := time.Now()
245 var empty bool
246 if m.UseTx {
247 // Run go-based migration inside a tx.
248 fn := m.DownFnContext
249 if direction {
250 fn = m.UpFnContext
251 }
252 empty = (fn == nil)
253 if err := runGoMigration(
254 ctx,
255 db,
256 fn,
257 m.Version,
258 direction,
259 !m.noVersioning,
260 ); err != nil {
261 return fmt.Errorf("ERROR go migration: %q: %w", filepath.Base(m.Source), err)
262 }
263 } else {
264 // Run go-based migration outside a tx.
265 fn := m.DownFnNoTxContext
266 if direction {
267 fn = m.UpFnNoTxContext
268 }
269 empty = (fn == nil)
270 if err := runGoMigrationNoTx(
271 ctx,
272 db,
273 fn,
274 m.Version,
275 direction,
276 !m.noVersioning,
277 ); err != nil {
278 return fmt.Errorf("ERROR go migration no tx: %q: %w", filepath.Base(m.Source), err)
279 }
280 }
281 finish := truncateDuration(time.Since(start))
282 if !empty {
283 log.Printf("OK %s (%s)", filepath.Base(m.Source), finish)
284 } else {
285 log.Printf("EMPTY %s (%s)", filepath.Base(m.Source), finish)
286 }
287 }
288 return nil
289}
290
291func runGoMigrationNoTx(
292 ctx context.Context,
293 db *sql.DB,
294 fn GoMigrationNoTxContext,
295 version int64,
296 direction bool,
297 recordVersion bool,
298) error {
299 if fn != nil {
300 // Run go migration function.
301 if err := fn(ctx, db); err != nil {
302 return fmt.Errorf("failed to run go migration: %w", err)
303 }
304 }
305 if recordVersion {
306 return insertOrDeleteVersionNoTx(ctx, db, version, direction)
307 }
308 return nil
309}
310
311func runGoMigration(
312 ctx context.Context,
313 db *sql.DB,
314 fn GoMigrationContext,
315 version int64,
316 direction bool,
317 recordVersion bool,
318) error {
319 if fn == nil && !recordVersion {
320 return nil
321 }
322 tx, err := db.BeginTx(ctx, nil)
323 if err != nil {
324 return fmt.Errorf("failed to begin transaction: %w", err)
325 }
326 if fn != nil {
327 // Run go migration function.
328 if err := fn(ctx, tx); err != nil {
329 _ = tx.Rollback()
330 return fmt.Errorf("failed to run go migration: %w", err)
331 }
332 }
333 if recordVersion {
334 if err := insertOrDeleteVersion(ctx, tx, version, direction); err != nil {
335 _ = tx.Rollback()
336 return fmt.Errorf("failed to update version: %w", err)
337 }
338 }
339 if err := tx.Commit(); err != nil {
340 return fmt.Errorf("failed to commit transaction: %w", err)
341 }
342 return nil
343}
344
345func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
346 if direction {
347 return store.InsertVersion(ctx, tx, TableName(), version)
348 }
349 return store.DeleteVersion(ctx, tx, TableName(), version)
350}
351
352func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
353 if direction {
354 return store.InsertVersionNoTx(ctx, db, TableName(), version)
355 }
356 return store.DeleteVersionNoTx(ctx, db, TableName(), version)
357}
358
// NumericComponent parses the version from the migration file name.
//
// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of
// migration, either .sql or .go.
//
// Only the base name of filename is inspected. The extension check is case-sensitive. The
// text before the first '_' must parse as a base-10 int64 of at least 1.
func NumericComponent(filename string) (int64, error) {
	base := filepath.Base(filename)

	switch filepath.Ext(base) {
	case ".go", ".sql":
		// supported
	default:
		return 0, errors.New("migration file does not have .sql or .go file extension")
	}

	// Split only at the first separator; everything after it is the description.
	parts := strings.SplitN(base, "_", 2)
	if len(parts) < 2 {
		return 0, errors.New("no filename separator '_' found")
	}

	version, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", base, err)
	}
	if version < 1 {
		return 0, errors.New("migration version must be greater than zero")
	}
	return version, nil
}
381
// truncateDuration rounds d for display, keeping roughly two decimal places in its largest unit.
// Durations above 1s round to 10ms, above 1ms to 10µs, and above 1µs to 10ns. Anything at or
// below 1µs, including zero and negative durations, is returned unchanged.
func truncateDuration(d time.Duration) time.Duration {
	switch {
	case d > time.Second:
		return d.Round(time.Second / 100)
	case d > time.Millisecond:
		return d.Round(time.Millisecond / 100)
	case d > time.Microsecond:
		return d.Round(time.Microsecond / 100)
	default:
		return d
	}
}
394
395// ref returns a string that identifies the migration. This is used for logging and error messages.
396func (m *Migration) ref() string {
397 return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version)
398}