migrate.go

  1package goose
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"go.uber.org/multierr"
  9	"io/fs"
 10	"math"
 11	"path"
 12	"sort"
 13	"strings"
 14	"time"
 15)
 16
 17var (
 18	// ErrNoMigrationFiles when no migration files have been found.
 19	ErrNoMigrationFiles = errors.New("no migration files found")
 20	// ErrNoCurrentVersion when a current migration version is not found.
 21	ErrNoCurrentVersion = errors.New("no current version found")
 22	// ErrNoNextVersion when the next migration version is not found.
 23	ErrNoNextVersion = errors.New("no next version found")
 24	// MaxVersion is the maximum allowed version.
 25	MaxVersion int64 = math.MaxInt64
 26)
 27
 28// Migrations slice.
 29type Migrations []*Migration
 30
 31// helpers so we can use pkg sort
 32func (ms Migrations) Len() int      { return len(ms) }
 33func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
 34func (ms Migrations) Less(i, j int) bool {
 35	if ms[i].Version == ms[j].Version {
 36		panic(fmt.Sprintf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source))
 37	}
 38	return ms[i].Version < ms[j].Version
 39}
 40
 41// Current gets the current migration.
 42func (ms Migrations) Current(current int64) (*Migration, error) {
 43	for i, migration := range ms {
 44		if migration.Version == current {
 45			return ms[i], nil
 46		}
 47	}
 48
 49	return nil, ErrNoCurrentVersion
 50}
 51
 52// Next gets the next migration.
 53func (ms Migrations) Next(current int64) (*Migration, error) {
 54	for i, migration := range ms {
 55		if migration.Version > current {
 56			return ms[i], nil
 57		}
 58	}
 59
 60	return nil, ErrNoNextVersion
 61}
 62
 63// Previous : Get the previous migration.
 64func (ms Migrations) Previous(current int64) (*Migration, error) {
 65	for i := len(ms) - 1; i >= 0; i-- {
 66		if ms[i].Version < current {
 67			return ms[i], nil
 68		}
 69	}
 70
 71	return nil, ErrNoNextVersion
 72}
 73
 74// Last gets the last migration.
 75func (ms Migrations) Last() (*Migration, error) {
 76	if len(ms) == 0 {
 77		return nil, ErrNoNextVersion
 78	}
 79
 80	return ms[len(ms)-1], nil
 81}
 82
 83// Versioned gets versioned migrations.
 84func (ms Migrations) versioned() (Migrations, error) {
 85	var migrations Migrations
 86
 87	// assume that the user will never have more than 19700101000000 migrations
 88	for _, m := range ms {
 89		// parse version as timestamp
 90		versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
 91
 92		if versionTime.Before(time.Unix(0, 0)) || err != nil {
 93			migrations = append(migrations, m)
 94		}
 95	}
 96
 97	return migrations, nil
 98}
 99
100// Timestamped gets the timestamped migrations.
101func (ms Migrations) timestamped() (Migrations, error) {
102	var migrations Migrations
103
104	// assume that the user will never have more than 19700101000000 migrations
105	for _, m := range ms {
106		// parse version as timestamp
107		versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
108		if err != nil {
109			// probably not a timestamp
110			continue
111		}
112
113		if versionTime.After(time.Unix(0, 0)) {
114			migrations = append(migrations, m)
115		}
116	}
117	return migrations, nil
118}
119
120func (ms Migrations) String() string {
121	str := ""
122	for _, m := range ms {
123		str += fmt.Sprintln(m)
124	}
125	return str
126}
127
128func collectMigrationsFS(
129	fsys fs.FS,
130	dirpath string,
131	current, target int64,
132	registered map[int64]*Migration,
133) (Migrations, error) {
134	if _, err := fs.Stat(fsys, dirpath); err != nil {
135		if errors.Is(err, fs.ErrNotExist) {
136			return nil, fmt.Errorf("%s directory does not exist", dirpath)
137		}
138		return nil, err
139	}
140	var migrations Migrations
141	// SQL migration files.
142	sqlMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.sql"))
143	if err != nil {
144		return nil, err
145	}
146	for _, file := range sqlMigrationFiles {
147		v, err := NumericComponent(file)
148		if err != nil {
149			return nil, fmt.Errorf("could not parse SQL migration file %q: %w", file, err)
150		}
151		if versionFilter(v, current, target) {
152			migrations = append(migrations, &Migration{
153				Version:  v,
154				Next:     -1,
155				Previous: -1,
156				Source:   file,
157			})
158		}
159	}
160	// Go migration files.
161	goMigrations, err := collectGoMigrations(fsys, dirpath, registered, current, target)
162	if err != nil {
163		return nil, err
164	}
165	migrations = append(migrations, goMigrations...)
166	if len(migrations) == 0 {
167		return nil, ErrNoMigrationFiles
168	}
169	return sortAndConnectMigrations(migrations), nil
170}
171
172// CollectMigrations returns all the valid looking migration scripts in the
173// migrations folder and go func registry, and key them by version.
174func CollectMigrations(dirpath string, current, target int64) (Migrations, error) {
175	return collectMigrationsFS(baseFS, dirpath, current, target, registeredGoMigrations)
176}
177
178func sortAndConnectMigrations(migrations Migrations) Migrations {
179	sort.Sort(migrations)
180
181	// now that we're sorted in the appropriate direction,
182	// populate next and previous for each migration
183	for i, m := range migrations {
184		prev := int64(-1)
185		if i > 0 {
186			prev = migrations[i-1].Version
187			migrations[i-1].Next = m.Version
188		}
189		migrations[i].Previous = prev
190	}
191
192	return migrations
193}
194
// versionFilter reports whether version v must be visited when moving from
// current to target.
//
// When target is above current the accepted range is (current, target]; when
// target is below current it is (target, current]. If current and target are
// equal nothing is accepted.
func versionFilter(v, current, target int64) bool {
	switch {
	case current < target:
		return current < v && v <= target
	case target < current:
		return target < v && v <= current
	default:
		return false
	}
}
204
205// EnsureDBVersion retrieves the current version for this DB.
206// Create and initialize the DB version table if it doesn't exist.
207func EnsureDBVersion(db *sql.DB) (int64, error) {
208	ctx := context.Background()
209	return EnsureDBVersionContext(ctx, db)
210}
211
212// EnsureDBVersionContext retrieves the current version for this DB.
213// Create and initialize the DB version table if it doesn't exist.
214func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
215	dbMigrations, err := store.ListMigrations(ctx, db, TableName())
216	if err != nil {
217		createErr := createVersionTable(ctx, db)
218		if createErr != nil {
219			return 0, multierr.Append(err, createErr)
220		}
221		return 0, nil
222	}
223	// The most recent record for each migration specifies
224	// whether it has been applied or rolled back.
225	// The first version we find that has been applied is the current version.
226	//
227	// TODO(mf): for historic reasons, we continue to use the is_applied column,
228	// but at some point we need to deprecate this logic and ideally remove
229	// this column.
230	//
231	// For context, see:
232	// https://github.com/pressly/goose/pull/131#pullrequestreview-178409168
233	//
234	// The dbMigrations list is expected to be ordered by descending ID. But
235	// in the future we should be able to query the last record only.
236	skipLookup := make(map[int64]struct{})
237	for _, m := range dbMigrations {
238		// Have we already marked this version to be skipped?
239		if _, ok := skipLookup[m.VersionID]; ok {
240			continue
241		}
242		// If version has been applied we are done.
243		if m.IsApplied {
244			return m.VersionID, nil
245		}
246		// Latest version of migration has not been applied.
247		skipLookup[m.VersionID] = struct{}{}
248	}
249	return 0, ErrNoNextVersion
250}
251
252// createVersionTable creates the db version table and inserts the
253// initial 0 value into it.
254func createVersionTable(ctx context.Context, db *sql.DB) error {
255	txn, err := db.BeginTx(ctx, nil)
256	if err != nil {
257		return err
258	}
259	if err := store.CreateVersionTable(ctx, txn, TableName()); err != nil {
260		_ = txn.Rollback()
261		return err
262	}
263	if err := store.InsertVersion(ctx, txn, TableName(), 0); err != nil {
264		_ = txn.Rollback()
265		return err
266	}
267	return txn.Commit()
268}
269
270// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
271func GetDBVersion(db *sql.DB) (int64, error) {
272	ctx := context.Background()
273	return GetDBVersionContext(ctx, db)
274}
275
276// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error.
277func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
278	version, err := EnsureDBVersionContext(ctx, db)
279	if err != nil {
280		return -1, err
281	}
282
283	return version, nil
284}
285
286// collectGoMigrations collects Go migrations from the filesystem and merges them with registered
287// migrations.
288//
289// If Go migrations have been registered globally, with [goose.AddNamedMigration...], but there are
290// no corresponding .go files in the filesystem, add them to the migrations slice.
291//
292// If Go migrations have been registered, and there are .go files in the filesystem dirpath, ONLY
293// include those in the migrations slices.
294//
295// Lastly, if there are .go files in the filesystem but they have not been registered, raise an
296// error. This is to prevent users from accidentally adding valid looking Go files to the migrations
297// folder without registering them.
298func collectGoMigrations(
299	fsys fs.FS,
300	dirpath string,
301	registeredGoMigrations map[int64]*Migration,
302	current, target int64,
303) (Migrations, error) {
304	// Sanity check registered migrations have the correct version prefix.
305	for _, m := range registeredGoMigrations {
306		if _, err := NumericComponent(m.Source); err != nil {
307			return nil, fmt.Errorf("could not parse go migration file %s: %w", m.Source, err)
308		}
309	}
310	goFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go"))
311	if err != nil {
312		return nil, err
313	}
314	// If there are no Go files in the filesystem and no registered Go migrations, return early.
315	if len(goFiles) == 0 && len(registeredGoMigrations) == 0 {
316		return nil, nil
317	}
318	type source struct {
319		fullpath string
320		version  int64
321	}
322	// Find all Go files that have a version prefix and are within the requested range.
323	var sources []source
324	for _, fullpath := range goFiles {
325		v, err := NumericComponent(fullpath)
326		if err != nil {
327			continue // Skip any files that don't have version prefix.
328		}
329		if strings.HasSuffix(fullpath, "_test.go") {
330			continue // Skip Go test files.
331		}
332		if versionFilter(v, current, target) {
333			sources = append(sources, source{
334				fullpath: fullpath,
335				version:  v,
336			})
337		}
338	}
339	var (
340		migrations Migrations
341	)
342	if len(sources) > 0 {
343		for _, s := range sources {
344			migration, ok := registeredGoMigrations[s.version]
345			if ok {
346				migrations = append(migrations, migration)
347			} else {
348				// TODO(mf): something that bothers me about this implementation is it will be
349				// lazily evaluated and the error will only be raised if the user tries to run the
350				// migration. It would be better to raise an error much earlier in the process.
351				migrations = append(migrations, &Migration{
352					Version:    s.version,
353					Next:       -1,
354					Previous:   -1,
355					Source:     s.fullpath,
356					Registered: false,
357				})
358			}
359		}
360	} else {
361		// Some users may register Go migrations manually via AddNamedMigration_ functions but not
362		// provide the corresponding .go files in the filesystem. In this case, we include them
363		// wholesale in the migrations slice.
364		//
365		// This is a valid use case because users may want to build a custom binary that only embeds
366		// the SQL migration files and some other mechanism for registering Go migrations.
367		for _, migration := range registeredGoMigrations {
368			v, err := NumericComponent(migration.Source)
369			if err != nil {
370				return nil, fmt.Errorf("could not parse go migration file %s: %w", migration.Source, err)
371			}
372			if versionFilter(v, current, target) {
373				migrations = append(migrations, migration)
374			}
375		}
376	}
377	return migrations, nil
378}