1package goose
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "io/fs"
9 "math"
10 "strconv"
11 "strings"
12 "sync"
13
14 "github.com/pressly/goose/v3/database"
15 "github.com/pressly/goose/v3/internal/controller"
16 "github.com/pressly/goose/v3/internal/gooseutil"
17 "github.com/pressly/goose/v3/internal/sqlparser"
18 "go.uber.org/multierr"
19)
20
21// Provider is a goose migration provider.
22type Provider struct {
23 // mu protects all accesses to the provider and must be held when calling operations on the
24 // database.
25 mu sync.Mutex
26
27 db *sql.DB
28 store *controller.StoreController
29 versionTableOnce sync.Once
30
31 fsys fs.FS
32 cfg config
33
34 // migrations are ordered by version in ascending order. This list will never be empty and
35 // contains all migrations known to the provider.
36 migrations []*Migration
37}
38
39// NewProvider returns a new goose provider.
40//
41// The caller is responsible for matching the database dialect with the database/sql driver. For
42// example, if the database dialect is "postgres", the database/sql driver could be
43// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
44// constant backed by a default [database.Store] implementation. For more advanced use cases, such
45// as using a custom table name or supplying a custom store implementation, see [WithStore].
46//
47// fsys is the filesystem used to read migration files, but may be nil. Most users will want to use
48// [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
49// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
50// migrations using [fs.Sub].
51//
52// See [ProviderOption] for more information on configuring the provider.
53//
54// Unless otherwise specified, all methods on Provider are safe for concurrent use.
55func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
56 if db == nil {
57 return nil, errors.New("db must not be nil")
58 }
59 if fsys == nil {
60 fsys = noopFS{}
61 }
62 cfg := config{
63 registered: make(map[int64]*Migration),
64 excludePaths: make(map[string]bool),
65 excludeVersions: make(map[int64]bool),
66 logger: &stdLogger{},
67 }
68 for _, opt := range opts {
69 if err := opt.apply(&cfg); err != nil {
70 return nil, err
71 }
72 }
73 // Allow users to specify a custom store implementation, but only if they don't specify a
74 // dialect. If they specify a dialect, we'll use the default store implementation.
75 if dialect == "" && cfg.store == nil {
76 return nil, errors.New("dialect must not be empty")
77 }
78 if dialect != "" && cfg.store != nil {
79 return nil, errors.New("dialect must be empty when using a custom store implementation")
80 }
81 var store database.Store
82 if dialect != "" {
83 var err error
84 store, err = database.NewStore(dialect, DefaultTablename)
85 if err != nil {
86 return nil, err
87 }
88 } else {
89 store = cfg.store
90 }
91 if store.Tablename() == "" {
92 return nil, errors.New("invalid store implementation: table name must not be empty")
93 }
94 return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
95}
96
97func newProvider(
98 db *sql.DB,
99 store database.Store,
100 fsys fs.FS,
101 cfg config,
102 global map[int64]*Migration,
103) (*Provider, error) {
104 // Collect migrations from the filesystem and merge with registered migrations.
105 //
106 // Note, we don't parse SQL migrations here. They are parsed lazily when required.
107
108 // feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
109 // an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
110 // we should make it optional.
111 filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
112 if err != nil {
113 return nil, err
114 }
115 versionToGoMigration := make(map[int64]*Migration)
116 // Add user-registered Go migrations from the provider.
117 for version, m := range cfg.registered {
118 versionToGoMigration[version] = m
119 }
120 // Skip adding global Go migrations if explicitly disabled.
121 if cfg.disableGlobalRegistry {
122 // TODO(mf): let's add a warn-level log here to inform users if len(global) > 0. Would like
123 // to add this once we're on go1.21 and leverage the new slog package.
124 } else {
125 for version, m := range global {
126 if _, ok := versionToGoMigration[version]; ok {
127 return nil, fmt.Errorf("global go migration conflicts with provider-registered go migration with version %d", version)
128 }
129 versionToGoMigration[version] = m
130 }
131 }
132 // At this point we have all registered unique Go migrations (if any). We need to merge them
133 // with SQL migrations from the filesystem.
134 migrations, err := merge(filesystemSources, versionToGoMigration)
135 if err != nil {
136 return nil, err
137 }
138 if len(migrations) == 0 {
139 return nil, ErrNoMigrations
140 }
141 return &Provider{
142 db: db,
143 fsys: fsys,
144 cfg: cfg,
145 store: controller.NewStoreController(store),
146 migrations: migrations,
147 }, nil
148}
149
150// Status returns the status of all migrations, merging the list of migrations from the database and
151// filesystem. The returned items are ordered by version, in ascending order.
152func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
153 return p.status(ctx)
154}
155
156// HasPending returns true if there are pending migrations to apply, otherwise, it returns false. If
157// out-of-order migrations are disabled, yet some are detected, this method returns an error.
158//
159// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
160// for pending migrations without blocking or being blocked by other operations.
161func (p *Provider) HasPending(ctx context.Context) (bool, error) {
162 return p.hasPending(ctx)
163}
164
165// GetVersions returns the max database version and the target version to migrate to.
166//
167// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
168// for versions without blocking or being blocked by other operations.
169func (p *Provider) GetVersions(ctx context.Context) (current, target int64, err error) {
170 return p.getVersions(ctx)
171}
172
173// GetDBVersion returns the highest version recorded in the database, regardless of the order in
174// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
175// this method returns 4. If no migrations have been applied, it returns 0.
176func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
177 if p.cfg.disableVersioning {
178 return -1, errors.New("getting database version not supported when versioning is disabled")
179 }
180 return p.getDBMaxVersion(ctx, nil)
181}
182
183// ListSources returns a list of all migration sources known to the provider, sorted in ascending
184// order by version. The path field may be empty for manually registered migrations, such as Go
185// migrations registered using the [WithGoMigrations] option.
186func (p *Provider) ListSources() []*Source {
187 sources := make([]*Source, 0, len(p.migrations))
188 for _, m := range p.migrations {
189 sources = append(sources, &Source{
190 Type: m.Type,
191 Path: m.Source,
192 Version: m.Version,
193 })
194 }
195 return sources
196}
197
198// Ping attempts to ping the database to verify a connection is available.
199func (p *Provider) Ping(ctx context.Context) error {
200 return p.db.PingContext(ctx)
201}
202
203// Close closes the database connection initially supplied to the provider.
204func (p *Provider) Close() error {
205 return p.db.Close()
206}
207
208// ApplyVersion applies exactly one migration for the specified version. If there is no migration
209// available for the specified version, this method returns [ErrVersionNotFound]. If the migration
210// has already been applied, this method returns [ErrAlreadyApplied].
211//
212// The direction parameter determines the migration direction: true for up migration and false for
213// down migration.
214func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
215 res, err := p.apply(ctx, version, direction)
216 if err != nil {
217 return nil, err
218 }
219 // This should never happen, we must return exactly one result.
220 if len(res) != 1 {
221 versions := make([]string, 0, len(res))
222 for _, r := range res {
223 versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
224 }
225 return nil, fmt.Errorf(
226 "unexpected number of migrations applied running apply, expecting exactly one result: %v",
227 strings.Join(versions, ","),
228 )
229 }
230 return res[0], nil
231}
232
233// Up applies all pending migrations. If there are no new migrations to apply, this method returns
234// empty list and nil error.
235func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
236 hasPending, err := p.HasPending(ctx)
237 if err != nil {
238 return nil, err
239 }
240 if !hasPending {
241 return nil, nil
242 }
243 return p.up(ctx, false, math.MaxInt64)
244}
245
246// UpByOne applies the next pending migration. If there is no next migration to apply, this method
247// returns [ErrNoNextVersion].
248func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
249 hasPending, err := p.HasPending(ctx)
250 if err != nil {
251 return nil, err
252 }
253 if !hasPending {
254 return nil, ErrNoNextVersion
255 }
256 res, err := p.up(ctx, true, math.MaxInt64)
257 if err != nil {
258 return nil, err
259 }
260 if len(res) == 0 {
261 return nil, ErrNoNextVersion
262 }
263 // This should never happen, we must return exactly one result.
264 if len(res) != 1 {
265 versions := make([]string, 0, len(res))
266 for _, r := range res {
267 versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
268 }
269 return nil, fmt.Errorf(
270 "unexpected number of migrations applied running up-by-one, expecting exactly one result: %v",
271 strings.Join(versions, ","),
272 )
273 }
274 return res[0], nil
275}
276
277// UpTo applies all pending migrations up to, and including, the specified version. If there are no
278// migrations to apply, this method returns empty list and nil error.
279//
280// For example, if there are three new migrations (9,10,11) and the current database version is 8
281// with a requested version of 10, only versions 9,10 will be applied.
282func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
283 hasPending, err := p.HasPending(ctx)
284 if err != nil {
285 return nil, err
286 }
287 if !hasPending {
288 return nil, nil
289 }
290 return p.up(ctx, false, version)
291}
292
293// Down rolls back the most recently applied migration. If there are no migrations to rollback, this
294// method returns [ErrNoNextVersion].
295//
296// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
297// the migration version. This only applies in scenarios where migrations are allowed to be applied
298// out of order.
299func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
300 res, err := p.down(ctx, true, 0)
301 if err != nil {
302 return nil, err
303 }
304 if len(res) == 0 {
305 return nil, ErrNoNextVersion
306 }
307 // This should never happen, we must return exactly one result.
308 if len(res) != 1 {
309 versions := make([]string, 0, len(res))
310 for _, r := range res {
311 versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
312 }
313 return nil, fmt.Errorf(
314 "unexpected number of migrations applied running down, expecting exactly one result: %v",
315 strings.Join(versions, ","),
316 )
317 }
318 return res[0], nil
319}
320
321// DownTo rolls back all migrations down to, but not including, the specified version.
322//
323// For example, if the current database version is 11,10,9... and the requested version is 9, only
324// migrations 11, 10 will be rolled back.
325//
326// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
327// the migration version. This only applies in scenarios where migrations are allowed to be applied
328// out of order.
329func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
330 if version < 0 {
331 return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
332 }
333 return p.down(ctx, false, version)
334}
335
336// *** Internal methods ***
337
338func (p *Provider) up(
339 ctx context.Context,
340 byOne bool,
341 version int64,
342) (_ []*MigrationResult, retErr error) {
343 if version < 1 {
344 return nil, errInvalidVersion
345 }
346 conn, cleanup, err := p.initialize(ctx, true)
347 if err != nil {
348 return nil, fmt.Errorf("failed to initialize: %w", err)
349 }
350 defer func() {
351 retErr = multierr.Append(retErr, cleanup())
352 }()
353
354 if len(p.migrations) == 0 {
355 return nil, nil
356 }
357 var apply []*Migration
358 if p.cfg.disableVersioning {
359 if byOne {
360 return nil, errors.New("up-by-one not supported when versioning is disabled")
361 }
362 apply = p.migrations
363 } else {
364 // optimize(mf): Listing all migrations from the database isn't great.
365 //
366 // The ideal implementation would be to query for the current max version and then apply
367 // migrations greater than that version. However, a nice property of the current
368 // implementation is that we can make stronger guarantees about unapplied migrations.
369 //
370 // In cases where users do not use out-of-order migrations, we want to surface an error if
371 // there are older unapplied migrations. See https://github.com/pressly/goose/issues/761 for
372 // more details.
373 //
374 // And in cases where users do use out-of-order migrations, we need to build a list of older
375 // migrations that need to be applied, so we need to query for all migrations anyways.
376 dbMigrations, err := p.store.ListMigrations(ctx, conn)
377 if err != nil {
378 return nil, err
379 }
380 if len(dbMigrations) == 0 {
381 return nil, errMissingZeroVersion
382 }
383 versions, err := gooseutil.UpVersions(
384 getVersionsFromMigrations(p.migrations), // fsys versions
385 getVersionsFromListMigrations(dbMigrations), // db versions
386 version,
387 p.cfg.allowMissing,
388 )
389 if err != nil {
390 return nil, err
391 }
392 for _, v := range versions {
393 m, err := p.getMigration(v)
394 if err != nil {
395 return nil, err
396 }
397 apply = append(apply, m)
398 }
399 }
400 return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne)
401}
402
403func (p *Provider) down(
404 ctx context.Context,
405 byOne bool,
406 version int64,
407) (_ []*MigrationResult, retErr error) {
408 conn, cleanup, err := p.initialize(ctx, true)
409 if err != nil {
410 return nil, fmt.Errorf("failed to initialize: %w", err)
411 }
412 defer func() {
413 retErr = multierr.Append(retErr, cleanup())
414 }()
415
416 if len(p.migrations) == 0 {
417 return nil, nil
418 }
419 if p.cfg.disableVersioning {
420 var downMigrations []*Migration
421 if byOne {
422 last := p.migrations[len(p.migrations)-1]
423 downMigrations = []*Migration{last}
424 } else {
425 downMigrations = p.migrations
426 }
427 return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne)
428 }
429 dbMigrations, err := p.store.ListMigrations(ctx, conn)
430 if err != nil {
431 return nil, err
432 }
433 if len(dbMigrations) == 0 {
434 return nil, errMissingZeroVersion
435 }
436 // We never migrate the zero version down.
437 if dbMigrations[0].Version == 0 {
438 p.printf("no migrations to run, current version: 0")
439 return nil, nil
440 }
441 var apply []*Migration
442 for _, dbMigration := range dbMigrations {
443 if dbMigration.Version <= version {
444 break
445 }
446 m, err := p.getMigration(dbMigration.Version)
447 if err != nil {
448 return nil, err
449 }
450 apply = append(apply, m)
451 }
452 return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne)
453}
454
455func (p *Provider) apply(
456 ctx context.Context,
457 version int64,
458 direction bool,
459) (_ []*MigrationResult, retErr error) {
460 if version < 1 {
461 return nil, errInvalidVersion
462 }
463 m, err := p.getMigration(version)
464 if err != nil {
465 return nil, err
466 }
467 conn, cleanup, err := p.initialize(ctx, true)
468 if err != nil {
469 return nil, fmt.Errorf("failed to initialize: %w", err)
470 }
471 defer func() {
472 retErr = multierr.Append(retErr, cleanup())
473 }()
474
475 result, err := p.store.GetMigration(ctx, conn, version)
476 if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
477 return nil, err
478 }
479 // There are a few states here:
480 // 1. direction is up
481 // a. migration is applied, this is an error (ErrAlreadyApplied)
482 // b. migration is not applied, apply it
483 if direction && result != nil {
484 return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
485 }
486 // 2. direction is down
487 // a. migration is applied, rollback
488 // b. migration is not applied, this is an error (ErrNotApplied)
489 if !direction && result == nil {
490 return nil, fmt.Errorf("version %d: %w", version, ErrNotApplied)
491 }
492 d := sqlparser.DirectionDown
493 if direction {
494 d = sqlparser.DirectionUp
495 }
496 return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
497}
498
499func (p *Provider) getVersions(ctx context.Context) (current, target int64, retErr error) {
500 conn, cleanup, err := p.initialize(ctx, false)
501 if err != nil {
502 return -1, -1, fmt.Errorf("failed to initialize: %w", err)
503 }
504 defer func() {
505 retErr = multierr.Append(retErr, cleanup())
506 }()
507
508 target = p.migrations[len(p.migrations)-1].Version
509
510 // If versioning is disabled, we always have pending migrations and the target version is the
511 // last migration.
512 if p.cfg.disableVersioning {
513 return -1, target, nil
514 }
515
516 current, err = p.store.GetLatestVersion(ctx, conn)
517 if err != nil {
518 if errors.Is(err, database.ErrVersionNotFound) {
519 return -1, target, errMissingZeroVersion
520 }
521 return -1, target, err
522 }
523 return current, target, nil
524}
525
526func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
527 conn, cleanup, err := p.initialize(ctx, false)
528 if err != nil {
529 return false, fmt.Errorf("failed to initialize: %w", err)
530 }
531 defer func() {
532 retErr = multierr.Append(retErr, cleanup())
533 }()
534
535 // If versioning is disabled, we always have pending migrations.
536 if p.cfg.disableVersioning {
537 return true, nil
538 }
539
540 // List all migrations from the database. Careful, optimizations here can lead to subtle bugs.
541 // We have 2 important cases to consider:
542 //
543 // 1. Users have enabled out-of-order migrations, in which case we need to check if any
544 // migrations are missing and report that there are pending migrations. Do not surface an
545 // error because this is a valid state.
546 //
547 // 2. Users have disabled out-of-order migrations (default), in which case we need to check if all
548 // migrations have been applied. We cannot check for the highest applied version because we lose the
549 // ability to surface an error if an out-of-order migration was introduced. It would be silently
550 // ignored and the user would not know that they have unapplied migrations.
551 //
552 // Maybe we could consider adding a flag to the provider such as IgnoreMissing, which would
553 // allow silently ignoring missing migrations. This would be useful for users that have built
554 // checks that prevent missing migrations from being introduced.
555
556 dbMigrations, err := p.store.ListMigrations(ctx, conn)
557 if err != nil {
558 return false, err
559 }
560 apply, err := gooseutil.UpVersions(
561 getVersionsFromMigrations(p.migrations), // fsys versions
562 getVersionsFromListMigrations(dbMigrations), // db versions
563 math.MaxInt64,
564 p.cfg.allowMissing,
565 )
566 if err != nil {
567 return false, err
568 }
569 return len(apply) > 0, nil
570}
571
572func getVersionsFromMigrations(in []*Migration) []int64 {
573 out := make([]int64, 0, len(in))
574 for _, m := range in {
575 out = append(out, m.Version)
576 }
577 return out
578
579}
580
581func getVersionsFromListMigrations(in []*database.ListMigrationsResult) []int64 {
582 out := make([]int64, 0, len(in))
583 for _, m := range in {
584 out = append(out, m.Version)
585 }
586 return out
587
588}
589
590func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
591 conn, cleanup, err := p.initialize(ctx, true)
592 if err != nil {
593 return nil, fmt.Errorf("failed to initialize: %w", err)
594 }
595 defer func() {
596 retErr = multierr.Append(retErr, cleanup())
597 }()
598
599 status := make([]*MigrationStatus, 0, len(p.migrations))
600 for _, m := range p.migrations {
601 migrationStatus := &MigrationStatus{
602 Source: &Source{
603 Type: m.Type,
604 Path: m.Source,
605 Version: m.Version,
606 },
607 State: StatePending,
608 }
609 // If versioning is disabled, we can't check the database for applied migrations, so we
610 // assume all migrations are pending.
611 if !p.cfg.disableVersioning {
612 dbResult, err := p.store.GetMigration(ctx, conn, m.Version)
613 if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
614 return nil, err
615 }
616 if dbResult != nil {
617 migrationStatus.State = StateApplied
618 migrationStatus.AppliedAt = dbResult.Timestamp
619 }
620 }
621 status = append(status, migrationStatus)
622 }
623
624 return status, nil
625}
626
627// getDBMaxVersion returns the highest version recorded in the database, regardless of the order in
628// which migrations were applied. conn may be nil, in which case a connection is initialized.
629func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) {
630 if conn == nil {
631 var cleanup func() error
632 var err error
633 conn, cleanup, err = p.initialize(ctx, true)
634 if err != nil {
635 return 0, err
636 }
637 defer func() {
638 retErr = multierr.Append(retErr, cleanup())
639 }()
640 }
641
642 latest, err := p.store.GetLatestVersion(ctx, conn)
643 if err != nil {
644 if errors.Is(err, database.ErrVersionNotFound) {
645 return 0, errMissingZeroVersion
646 }
647 return -1, err
648 }
649 return latest, nil
650}