provider.go

  1package goose
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"io/fs"
  9	"math"
 10	"strconv"
 11	"strings"
 12	"sync"
 13
 14	"github.com/pressly/goose/v3/database"
 15	"github.com/pressly/goose/v3/internal/controller"
 16	"github.com/pressly/goose/v3/internal/gooseutil"
 17	"github.com/pressly/goose/v3/internal/sqlparser"
 18	"go.uber.org/multierr"
 19)
 20
// Provider is a goose migration provider. It pairs a database handle with the migrations
// collected from a filesystem and/or registered as Go migrations, and exposes operations to
// inspect and apply them. Construct one with NewProvider.
type Provider struct {
	// mu protects all accesses to the provider and must be held when calling operations on the
	// database.
	mu sync.Mutex

	// db is the database handle supplied to NewProvider; Close closes it.
	db *sql.DB
	// store wraps the database.Store (the dialect's default store, or a custom one supplied via
	// WithStore) used to read and write the version table.
	store *controller.StoreController
	// versionTableOnce is a one-time guard; NOTE(review): presumably used to create/check the
	// version table once per provider — its use is outside this section, confirm.
	versionTableOnce sync.Once

	// fsys is the filesystem migrations were collected from. NewProvider substitutes a no-op
	// filesystem when the caller passes nil.
	fsys fs.FS
	// cfg holds the settings produced by applying the ProviderOption values.
	cfg config

	// migrations are ordered by version in ascending order. This list will never be empty and
	// contains all migrations known to the provider.
	migrations []*Migration
}
 38
 39// NewProvider returns a new goose provider.
 40//
 41// The caller is responsible for matching the database dialect with the database/sql driver. For
 42// example, if the database dialect is "postgres", the database/sql driver could be
 43// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
 44// constant backed by a default [database.Store] implementation. For more advanced use cases, such
 45// as using a custom table name or supplying a custom store implementation, see [WithStore].
 46//
 47// fsys is the filesystem used to read migration files, but may be nil. Most users will want to use
 48// [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
 49// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
 50// migrations using [fs.Sub].
 51//
 52// See [ProviderOption] for more information on configuring the provider.
 53//
 54// Unless otherwise specified, all methods on Provider are safe for concurrent use.
 55func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
 56	if db == nil {
 57		return nil, errors.New("db must not be nil")
 58	}
 59	if fsys == nil {
 60		fsys = noopFS{}
 61	}
 62	cfg := config{
 63		registered:      make(map[int64]*Migration),
 64		excludePaths:    make(map[string]bool),
 65		excludeVersions: make(map[int64]bool),
 66		logger:          &stdLogger{},
 67	}
 68	for _, opt := range opts {
 69		if err := opt.apply(&cfg); err != nil {
 70			return nil, err
 71		}
 72	}
 73	// Allow users to specify a custom store implementation, but only if they don't specify a
 74	// dialect. If they specify a dialect, we'll use the default store implementation.
 75	if dialect == "" && cfg.store == nil {
 76		return nil, errors.New("dialect must not be empty")
 77	}
 78	if dialect != "" && cfg.store != nil {
 79		return nil, errors.New("dialect must be empty when using a custom store implementation")
 80	}
 81	var store database.Store
 82	if dialect != "" {
 83		var err error
 84		store, err = database.NewStore(dialect, DefaultTablename)
 85		if err != nil {
 86			return nil, err
 87		}
 88	} else {
 89		store = cfg.store
 90	}
 91	if store.Tablename() == "" {
 92		return nil, errors.New("invalid store implementation: table name must not be empty")
 93	}
 94	return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
 95}
 96
 97func newProvider(
 98	db *sql.DB,
 99	store database.Store,
100	fsys fs.FS,
101	cfg config,
102	global map[int64]*Migration,
103) (*Provider, error) {
104	// Collect migrations from the filesystem and merge with registered migrations.
105	//
106	// Note, we don't parse SQL migrations here. They are parsed lazily when required.
107
108	// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
109	// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
110	// we should make it optional.
111	filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
112	if err != nil {
113		return nil, err
114	}
115	versionToGoMigration := make(map[int64]*Migration)
116	// Add user-registered Go migrations from the provider.
117	for version, m := range cfg.registered {
118		versionToGoMigration[version] = m
119	}
120	// Skip adding global Go migrations if explicitly disabled.
121	if cfg.disableGlobalRegistry {
122		// TODO(mf): let's add a warn-level log here to inform users if len(global) > 0. Would like
123		// to add this once we're on go1.21 and leverage the new slog package.
124	} else {
125		for version, m := range global {
126			if _, ok := versionToGoMigration[version]; ok {
127				return nil, fmt.Errorf("global go migration conflicts with provider-registered go migration with version %d", version)
128			}
129			versionToGoMigration[version] = m
130		}
131	}
132	// At this point we have all registered unique Go migrations (if any). We need to merge them
133	// with SQL migrations from the filesystem.
134	migrations, err := merge(filesystemSources, versionToGoMigration)
135	if err != nil {
136		return nil, err
137	}
138	if len(migrations) == 0 {
139		return nil, ErrNoMigrations
140	}
141	return &Provider{
142		db:         db,
143		fsys:       fsys,
144		cfg:        cfg,
145		store:      controller.NewStoreController(store),
146		migrations: migrations,
147	}, nil
148}
149
150// Status returns the status of all migrations, merging the list of migrations from the database and
151// filesystem. The returned items are ordered by version, in ascending order.
152func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
153	return p.status(ctx)
154}
155
156// HasPending returns true if there are pending migrations to apply, otherwise, it returns false. If
157// out-of-order migrations are disabled, yet some are detected, this method returns an error.
158//
159// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
160// for pending migrations without blocking or being blocked by other operations.
161func (p *Provider) HasPending(ctx context.Context) (bool, error) {
162	return p.hasPending(ctx)
163}
164
165// GetVersions returns the max database version and the target version to migrate to.
166//
167// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
168// for versions without blocking or being blocked by other operations.
169func (p *Provider) GetVersions(ctx context.Context) (current, target int64, err error) {
170	return p.getVersions(ctx)
171}
172
173// GetDBVersion returns the highest version recorded in the database, regardless of the order in
174// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
175// this method returns 4. If no migrations have been applied, it returns 0.
176func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
177	if p.cfg.disableVersioning {
178		return -1, errors.New("getting database version not supported when versioning is disabled")
179	}
180	return p.getDBMaxVersion(ctx, nil)
181}
182
183// ListSources returns a list of all migration sources known to the provider, sorted in ascending
184// order by version. The path field may be empty for manually registered migrations, such as Go
185// migrations registered using the [WithGoMigrations] option.
186func (p *Provider) ListSources() []*Source {
187	sources := make([]*Source, 0, len(p.migrations))
188	for _, m := range p.migrations {
189		sources = append(sources, &Source{
190			Type:    m.Type,
191			Path:    m.Source,
192			Version: m.Version,
193		})
194	}
195	return sources
196}
197
198// Ping attempts to ping the database to verify a connection is available.
199func (p *Provider) Ping(ctx context.Context) error {
200	return p.db.PingContext(ctx)
201}
202
203// Close closes the database connection initially supplied to the provider.
204func (p *Provider) Close() error {
205	return p.db.Close()
206}
207
208// ApplyVersion applies exactly one migration for the specified version. If there is no migration
209// available for the specified version, this method returns [ErrVersionNotFound]. If the migration
210// has already been applied, this method returns [ErrAlreadyApplied].
211//
212// The direction parameter determines the migration direction: true for up migration and false for
213// down migration.
214func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
215	res, err := p.apply(ctx, version, direction)
216	if err != nil {
217		return nil, err
218	}
219	// This should never happen, we must return exactly one result.
220	if len(res) != 1 {
221		versions := make([]string, 0, len(res))
222		for _, r := range res {
223			versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
224		}
225		return nil, fmt.Errorf(
226			"unexpected number of migrations applied running apply, expecting exactly one result: %v",
227			strings.Join(versions, ","),
228		)
229	}
230	return res[0], nil
231}
232
233// Up applies all pending migrations. If there are no new migrations to apply, this method returns
234// empty list and nil error.
235func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
236	hasPending, err := p.HasPending(ctx)
237	if err != nil {
238		return nil, err
239	}
240	if !hasPending {
241		return nil, nil
242	}
243	return p.up(ctx, false, math.MaxInt64)
244}
245
246// UpByOne applies the next pending migration. If there is no next migration to apply, this method
247// returns [ErrNoNextVersion].
248func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
249	hasPending, err := p.HasPending(ctx)
250	if err != nil {
251		return nil, err
252	}
253	if !hasPending {
254		return nil, ErrNoNextVersion
255	}
256	res, err := p.up(ctx, true, math.MaxInt64)
257	if err != nil {
258		return nil, err
259	}
260	if len(res) == 0 {
261		return nil, ErrNoNextVersion
262	}
263	// This should never happen, we must return exactly one result.
264	if len(res) != 1 {
265		versions := make([]string, 0, len(res))
266		for _, r := range res {
267			versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
268		}
269		return nil, fmt.Errorf(
270			"unexpected number of migrations applied running up-by-one, expecting exactly one result: %v",
271			strings.Join(versions, ","),
272		)
273	}
274	return res[0], nil
275}
276
277// UpTo applies all pending migrations up to, and including, the specified version. If there are no
278// migrations to apply, this method returns empty list and nil error.
279//
280// For example, if there are three new migrations (9,10,11) and the current database version is 8
281// with a requested version of 10, only versions 9,10 will be applied.
282func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
283	hasPending, err := p.HasPending(ctx)
284	if err != nil {
285		return nil, err
286	}
287	if !hasPending {
288		return nil, nil
289	}
290	return p.up(ctx, false, version)
291}
292
293// Down rolls back the most recently applied migration. If there are no migrations to rollback, this
294// method returns [ErrNoNextVersion].
295//
296// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
297// the migration version. This only applies in scenarios where migrations are allowed to be applied
298// out of order.
299func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
300	res, err := p.down(ctx, true, 0)
301	if err != nil {
302		return nil, err
303	}
304	if len(res) == 0 {
305		return nil, ErrNoNextVersion
306	}
307	// This should never happen, we must return exactly one result.
308	if len(res) != 1 {
309		versions := make([]string, 0, len(res))
310		for _, r := range res {
311			versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
312		}
313		return nil, fmt.Errorf(
314			"unexpected number of migrations applied running down, expecting exactly one result: %v",
315			strings.Join(versions, ","),
316		)
317	}
318	return res[0], nil
319}
320
321// DownTo rolls back all migrations down to, but not including, the specified version.
322//
323// For example, if the current database version is 11,10,9... and the requested version is 9, only
324// migrations 11, 10 will be rolled back.
325//
326// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
327// the migration version. This only applies in scenarios where migrations are allowed to be applied
328// out of order.
329func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
330	if version < 0 {
331		return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
332	}
333	return p.down(ctx, false, version)
334}
335
336// *** Internal methods ***
337
338func (p *Provider) up(
339	ctx context.Context,
340	byOne bool,
341	version int64,
342) (_ []*MigrationResult, retErr error) {
343	if version < 1 {
344		return nil, errInvalidVersion
345	}
346	conn, cleanup, err := p.initialize(ctx, true)
347	if err != nil {
348		return nil, fmt.Errorf("failed to initialize: %w", err)
349	}
350	defer func() {
351		retErr = multierr.Append(retErr, cleanup())
352	}()
353
354	if len(p.migrations) == 0 {
355		return nil, nil
356	}
357	var apply []*Migration
358	if p.cfg.disableVersioning {
359		if byOne {
360			return nil, errors.New("up-by-one not supported when versioning is disabled")
361		}
362		apply = p.migrations
363	} else {
364		// optimize(mf): Listing all migrations from the database isn't great.
365		//
366		// The ideal implementation would be to query for the current max version and then apply
367		// migrations greater than that version. However, a nice property of the current
368		// implementation is that we can make stronger guarantees about unapplied migrations.
369		//
370		// In cases where users do not use out-of-order migrations, we want to surface an error if
371		// there are older unapplied migrations. See https://github.com/pressly/goose/issues/761 for
372		// more details.
373		//
374		// And in cases where users do use out-of-order migrations, we need to build a list of older
375		// migrations that need to be applied, so we need to query for all migrations anyways.
376		dbMigrations, err := p.store.ListMigrations(ctx, conn)
377		if err != nil {
378			return nil, err
379		}
380		if len(dbMigrations) == 0 {
381			return nil, errMissingZeroVersion
382		}
383		versions, err := gooseutil.UpVersions(
384			getVersionsFromMigrations(p.migrations),     // fsys versions
385			getVersionsFromListMigrations(dbMigrations), // db versions
386			version,
387			p.cfg.allowMissing,
388		)
389		if err != nil {
390			return nil, err
391		}
392		for _, v := range versions {
393			m, err := p.getMigration(v)
394			if err != nil {
395				return nil, err
396			}
397			apply = append(apply, m)
398		}
399	}
400	return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne)
401}
402
403func (p *Provider) down(
404	ctx context.Context,
405	byOne bool,
406	version int64,
407) (_ []*MigrationResult, retErr error) {
408	conn, cleanup, err := p.initialize(ctx, true)
409	if err != nil {
410		return nil, fmt.Errorf("failed to initialize: %w", err)
411	}
412	defer func() {
413		retErr = multierr.Append(retErr, cleanup())
414	}()
415
416	if len(p.migrations) == 0 {
417		return nil, nil
418	}
419	if p.cfg.disableVersioning {
420		var downMigrations []*Migration
421		if byOne {
422			last := p.migrations[len(p.migrations)-1]
423			downMigrations = []*Migration{last}
424		} else {
425			downMigrations = p.migrations
426		}
427		return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne)
428	}
429	dbMigrations, err := p.store.ListMigrations(ctx, conn)
430	if err != nil {
431		return nil, err
432	}
433	if len(dbMigrations) == 0 {
434		return nil, errMissingZeroVersion
435	}
436	// We never migrate the zero version down.
437	if dbMigrations[0].Version == 0 {
438		p.printf("no migrations to run, current version: 0")
439		return nil, nil
440	}
441	var apply []*Migration
442	for _, dbMigration := range dbMigrations {
443		if dbMigration.Version <= version {
444			break
445		}
446		m, err := p.getMigration(dbMigration.Version)
447		if err != nil {
448			return nil, err
449		}
450		apply = append(apply, m)
451	}
452	return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne)
453}
454
455func (p *Provider) apply(
456	ctx context.Context,
457	version int64,
458	direction bool,
459) (_ []*MigrationResult, retErr error) {
460	if version < 1 {
461		return nil, errInvalidVersion
462	}
463	m, err := p.getMigration(version)
464	if err != nil {
465		return nil, err
466	}
467	conn, cleanup, err := p.initialize(ctx, true)
468	if err != nil {
469		return nil, fmt.Errorf("failed to initialize: %w", err)
470	}
471	defer func() {
472		retErr = multierr.Append(retErr, cleanup())
473	}()
474
475	result, err := p.store.GetMigration(ctx, conn, version)
476	if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
477		return nil, err
478	}
479	// There are a few states here:
480	//  1. direction is up
481	//    a. migration is applied, this is an error (ErrAlreadyApplied)
482	//    b. migration is not applied, apply it
483	if direction && result != nil {
484		return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
485	}
486	//  2. direction is down
487	//    a. migration is applied, rollback
488	//    b. migration is not applied, this is an error (ErrNotApplied)
489	if !direction && result == nil {
490		return nil, fmt.Errorf("version %d: %w", version, ErrNotApplied)
491	}
492	d := sqlparser.DirectionDown
493	if direction {
494		d = sqlparser.DirectionUp
495	}
496	return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
497}
498
499func (p *Provider) getVersions(ctx context.Context) (current, target int64, retErr error) {
500	conn, cleanup, err := p.initialize(ctx, false)
501	if err != nil {
502		return -1, -1, fmt.Errorf("failed to initialize: %w", err)
503	}
504	defer func() {
505		retErr = multierr.Append(retErr, cleanup())
506	}()
507
508	target = p.migrations[len(p.migrations)-1].Version
509
510	// If versioning is disabled, we always have pending migrations and the target version is the
511	// last migration.
512	if p.cfg.disableVersioning {
513		return -1, target, nil
514	}
515
516	current, err = p.store.GetLatestVersion(ctx, conn)
517	if err != nil {
518		if errors.Is(err, database.ErrVersionNotFound) {
519			return -1, target, errMissingZeroVersion
520		}
521		return -1, target, err
522	}
523	return current, target, nil
524}
525
526func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
527	conn, cleanup, err := p.initialize(ctx, false)
528	if err != nil {
529		return false, fmt.Errorf("failed to initialize: %w", err)
530	}
531	defer func() {
532		retErr = multierr.Append(retErr, cleanup())
533	}()
534
535	// If versioning is disabled, we always have pending migrations.
536	if p.cfg.disableVersioning {
537		return true, nil
538	}
539
540	// List all migrations from the database. Careful, optimizations here can lead to subtle bugs.
541	// We have 2 important cases to consider:
542	//
543	//  1.  Users have enabled out-of-order migrations, in which case we need to check if any
544	//      migrations are missing and report that there are pending migrations. Do not surface an
545	//      error because this is a valid state.
546	//
547	//  2.  Users have disabled out-of-order migrations (default), in which case we need to check if all
548	//      migrations have been applied. We cannot check for the highest applied version because we lose the
549	//      ability to surface an error if an out-of-order migration was introduced. It would be silently
550	//      ignored and the user would not know that they have unapplied migrations.
551	//
552	//      Maybe we could consider adding a flag to the provider such as IgnoreMissing, which would
553	//      allow silently ignoring missing migrations. This would be useful for users that have built
554	//      checks that prevent missing migrations from being introduced.
555
556	dbMigrations, err := p.store.ListMigrations(ctx, conn)
557	if err != nil {
558		return false, err
559	}
560	apply, err := gooseutil.UpVersions(
561		getVersionsFromMigrations(p.migrations),     // fsys versions
562		getVersionsFromListMigrations(dbMigrations), // db versions
563		math.MaxInt64,
564		p.cfg.allowMissing,
565	)
566	if err != nil {
567		return false, err
568	}
569	return len(apply) > 0, nil
570}
571
572func getVersionsFromMigrations(in []*Migration) []int64 {
573	out := make([]int64, 0, len(in))
574	for _, m := range in {
575		out = append(out, m.Version)
576	}
577	return out
578
579}
580
581func getVersionsFromListMigrations(in []*database.ListMigrationsResult) []int64 {
582	out := make([]int64, 0, len(in))
583	for _, m := range in {
584		out = append(out, m.Version)
585	}
586	return out
587
588}
589
590func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
591	conn, cleanup, err := p.initialize(ctx, true)
592	if err != nil {
593		return nil, fmt.Errorf("failed to initialize: %w", err)
594	}
595	defer func() {
596		retErr = multierr.Append(retErr, cleanup())
597	}()
598
599	status := make([]*MigrationStatus, 0, len(p.migrations))
600	for _, m := range p.migrations {
601		migrationStatus := &MigrationStatus{
602			Source: &Source{
603				Type:    m.Type,
604				Path:    m.Source,
605				Version: m.Version,
606			},
607			State: StatePending,
608		}
609		// If versioning is disabled, we can't check the database for applied migrations, so we
610		// assume all migrations are pending.
611		if !p.cfg.disableVersioning {
612			dbResult, err := p.store.GetMigration(ctx, conn, m.Version)
613			if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
614				return nil, err
615			}
616			if dbResult != nil {
617				migrationStatus.State = StateApplied
618				migrationStatus.AppliedAt = dbResult.Timestamp
619			}
620		}
621		status = append(status, migrationStatus)
622	}
623
624	return status, nil
625}
626
627// getDBMaxVersion returns the highest version recorded in the database, regardless of the order in
628// which migrations were applied. conn may be nil, in which case a connection is initialized.
629func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) {
630	if conn == nil {
631		var cleanup func() error
632		var err error
633		conn, cleanup, err = p.initialize(ctx, true)
634		if err != nil {
635			return 0, err
636		}
637		defer func() {
638			retErr = multierr.Append(retErr, cleanup())
639		}()
640	}
641
642	latest, err := p.store.GetLatestVersion(ctx, conn)
643	if err != nil {
644		if errors.Is(err, database.ErrVersionNotFound) {
645			return 0, errMissingZeroVersion
646		}
647		return -1, err
648	}
649	return latest, nil
650}