1package goose
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "go.uber.org/multierr"
9 "io/fs"
10 "math"
11 "path"
12 "sort"
13 "strings"
14 "time"
15)
16
// Sentinel errors returned by the migration lookup and collection helpers in
// this file. Compare against them with errors.Is.
var (
	// ErrNoMigrationFiles is returned when no migration files have been found.
	ErrNoMigrationFiles = errors.New("no migration files found")

	// ErrNoCurrentVersion is returned when a current migration version is not found.
	ErrNoCurrentVersion = errors.New("no current version found")

	// ErrNoNextVersion is returned when the next migration version is not found.
	ErrNoNextVersion = errors.New("no next version found")
)

// MaxVersion is the maximum allowed version.
var MaxVersion int64 = math.MaxInt64
27
28// Migrations slice.
29type Migrations []*Migration
30
31// helpers so we can use pkg sort
32func (ms Migrations) Len() int { return len(ms) }
33func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
34func (ms Migrations) Less(i, j int) bool {
35 if ms[i].Version == ms[j].Version {
36 panic(fmt.Sprintf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source))
37 }
38 return ms[i].Version < ms[j].Version
39}
40
41// Current gets the current migration.
42func (ms Migrations) Current(current int64) (*Migration, error) {
43 for i, migration := range ms {
44 if migration.Version == current {
45 return ms[i], nil
46 }
47 }
48
49 return nil, ErrNoCurrentVersion
50}
51
52// Next gets the next migration.
53func (ms Migrations) Next(current int64) (*Migration, error) {
54 for i, migration := range ms {
55 if migration.Version > current {
56 return ms[i], nil
57 }
58 }
59
60 return nil, ErrNoNextVersion
61}
62
63// Previous : Get the previous migration.
64func (ms Migrations) Previous(current int64) (*Migration, error) {
65 for i := len(ms) - 1; i >= 0; i-- {
66 if ms[i].Version < current {
67 return ms[i], nil
68 }
69 }
70
71 return nil, ErrNoNextVersion
72}
73
74// Last gets the last migration.
75func (ms Migrations) Last() (*Migration, error) {
76 if len(ms) == 0 {
77 return nil, ErrNoNextVersion
78 }
79
80 return ms[len(ms)-1], nil
81}
82
83// Versioned gets versioned migrations.
84func (ms Migrations) versioned() (Migrations, error) {
85 var migrations Migrations
86
87 // assume that the user will never have more than 19700101000000 migrations
88 for _, m := range ms {
89 // parse version as timestamp
90 versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
91
92 if versionTime.Before(time.Unix(0, 0)) || err != nil {
93 migrations = append(migrations, m)
94 }
95 }
96
97 return migrations, nil
98}
99
100// Timestamped gets the timestamped migrations.
101func (ms Migrations) timestamped() (Migrations, error) {
102 var migrations Migrations
103
104 // assume that the user will never have more than 19700101000000 migrations
105 for _, m := range ms {
106 // parse version as timestamp
107 versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
108 if err != nil {
109 // probably not a timestamp
110 continue
111 }
112
113 if versionTime.After(time.Unix(0, 0)) {
114 migrations = append(migrations, m)
115 }
116 }
117 return migrations, nil
118}
119
120func (ms Migrations) String() string {
121 str := ""
122 for _, m := range ms {
123 str += fmt.Sprintln(m)
124 }
125 return str
126}
127
128func collectMigrationsFS(
129 fsys fs.FS,
130 dirpath string,
131 current, target int64,
132 registered map[int64]*Migration,
133) (Migrations, error) {
134 if _, err := fs.Stat(fsys, dirpath); err != nil {
135 if errors.Is(err, fs.ErrNotExist) {
136 return nil, fmt.Errorf("%s directory does not exist", dirpath)
137 }
138 return nil, err
139 }
140 var migrations Migrations
141 // SQL migration files.
142 sqlMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.sql"))
143 if err != nil {
144 return nil, err
145 }
146 for _, file := range sqlMigrationFiles {
147 v, err := NumericComponent(file)
148 if err != nil {
149 return nil, fmt.Errorf("could not parse SQL migration file %q: %w", file, err)
150 }
151 if versionFilter(v, current, target) {
152 migrations = append(migrations, &Migration{
153 Version: v,
154 Next: -1,
155 Previous: -1,
156 Source: file,
157 })
158 }
159 }
160 // Go migration files.
161 goMigrations, err := collectGoMigrations(fsys, dirpath, registered, current, target)
162 if err != nil {
163 return nil, err
164 }
165 migrations = append(migrations, goMigrations...)
166 if len(migrations) == 0 {
167 return nil, ErrNoMigrationFiles
168 }
169 return sortAndConnectMigrations(migrations), nil
170}
171
172// CollectMigrations returns all the valid looking migration scripts in the
173// migrations folder and go func registry, and key them by version.
174func CollectMigrations(dirpath string, current, target int64) (Migrations, error) {
175 return collectMigrationsFS(baseFS, dirpath, current, target, registeredGoMigrations)
176}
177
178func sortAndConnectMigrations(migrations Migrations) Migrations {
179 sort.Sort(migrations)
180
181 // now that we're sorted in the appropriate direction,
182 // populate next and previous for each migration
183 for i, m := range migrations {
184 prev := int64(-1)
185 if i > 0 {
186 prev = migrations[i-1].Version
187 migrations[i-1].Next = m.Version
188 }
189 migrations[i].Previous = prev
190 }
191
192 return migrations
193}
194
// versionFilter reports whether version v lies in the range to be processed
// when moving from current to target:
//
//   - moving up (target > current): current < v <= target
//   - moving down (target < current): target < v <= current
//   - no movement (target == current): always false
func versionFilter(v, current, target int64) bool {
	switch {
	case current < target:
		return current < v && v <= target
	case current > target:
		return target < v && v <= current
	default:
		return false
	}
}
204
205// EnsureDBVersion retrieves the current version for this DB.
206// Create and initialize the DB version table if it doesn't exist.
207func EnsureDBVersion(db *sql.DB) (int64, error) {
208 ctx := context.Background()
209 return EnsureDBVersionContext(ctx, db)
210}
211
212// EnsureDBVersionContext retrieves the current version for this DB.
213// Create and initialize the DB version table if it doesn't exist.
214func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
215 dbMigrations, err := store.ListMigrations(ctx, db, TableName())
216 if err != nil {
217 createErr := createVersionTable(ctx, db)
218 if createErr != nil {
219 return 0, multierr.Append(err, createErr)
220 }
221 return 0, nil
222 }
223 // The most recent record for each migration specifies
224 // whether it has been applied or rolled back.
225 // The first version we find that has been applied is the current version.
226 //
227 // TODO(mf): for historic reasons, we continue to use the is_applied column,
228 // but at some point we need to deprecate this logic and ideally remove
229 // this column.
230 //
231 // For context, see:
232 // https://github.com/pressly/goose/pull/131#pullrequestreview-178409168
233 //
234 // The dbMigrations list is expected to be ordered by descending ID. But
235 // in the future we should be able to query the last record only.
236 skipLookup := make(map[int64]struct{})
237 for _, m := range dbMigrations {
238 // Have we already marked this version to be skipped?
239 if _, ok := skipLookup[m.VersionID]; ok {
240 continue
241 }
242 // If version has been applied we are done.
243 if m.IsApplied {
244 return m.VersionID, nil
245 }
246 // Latest version of migration has not been applied.
247 skipLookup[m.VersionID] = struct{}{}
248 }
249 return 0, ErrNoNextVersion
250}
251
252// createVersionTable creates the db version table and inserts the
253// initial 0 value into it.
254func createVersionTable(ctx context.Context, db *sql.DB) error {
255 txn, err := db.BeginTx(ctx, nil)
256 if err != nil {
257 return err
258 }
259 if err := store.CreateVersionTable(ctx, txn, TableName()); err != nil {
260 _ = txn.Rollback()
261 return err
262 }
263 if err := store.InsertVersion(ctx, txn, TableName(), 0); err != nil {
264 _ = txn.Rollback()
265 return err
266 }
267 return txn.Commit()
268}
269
270// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
271func GetDBVersion(db *sql.DB) (int64, error) {
272 ctx := context.Background()
273 return GetDBVersionContext(ctx, db)
274}
275
276// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error.
277func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
278 version, err := EnsureDBVersionContext(ctx, db)
279 if err != nil {
280 return -1, err
281 }
282
283 return version, nil
284}
285
286// collectGoMigrations collects Go migrations from the filesystem and merges them with registered
287// migrations.
288//
289// If Go migrations have been registered globally, with [goose.AddNamedMigration...], but there are
290// no corresponding .go files in the filesystem, add them to the migrations slice.
291//
292// If Go migrations have been registered, and there are .go files in the filesystem dirpath, ONLY
293// include those in the migrations slices.
294//
295// Lastly, if there are .go files in the filesystem but they have not been registered, raise an
296// error. This is to prevent users from accidentally adding valid looking Go files to the migrations
297// folder without registering them.
298func collectGoMigrations(
299 fsys fs.FS,
300 dirpath string,
301 registeredGoMigrations map[int64]*Migration,
302 current, target int64,
303) (Migrations, error) {
304 // Sanity check registered migrations have the correct version prefix.
305 for _, m := range registeredGoMigrations {
306 if _, err := NumericComponent(m.Source); err != nil {
307 return nil, fmt.Errorf("could not parse go migration file %s: %w", m.Source, err)
308 }
309 }
310 goFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go"))
311 if err != nil {
312 return nil, err
313 }
314 // If there are no Go files in the filesystem and no registered Go migrations, return early.
315 if len(goFiles) == 0 && len(registeredGoMigrations) == 0 {
316 return nil, nil
317 }
318 type source struct {
319 fullpath string
320 version int64
321 }
322 // Find all Go files that have a version prefix and are within the requested range.
323 var sources []source
324 for _, fullpath := range goFiles {
325 v, err := NumericComponent(fullpath)
326 if err != nil {
327 continue // Skip any files that don't have version prefix.
328 }
329 if strings.HasSuffix(fullpath, "_test.go") {
330 continue // Skip Go test files.
331 }
332 if versionFilter(v, current, target) {
333 sources = append(sources, source{
334 fullpath: fullpath,
335 version: v,
336 })
337 }
338 }
339 var (
340 migrations Migrations
341 )
342 if len(sources) > 0 {
343 for _, s := range sources {
344 migration, ok := registeredGoMigrations[s.version]
345 if ok {
346 migrations = append(migrations, migration)
347 } else {
348 // TODO(mf): something that bothers me about this implementation is it will be
349 // lazily evaluated and the error will only be raised if the user tries to run the
350 // migration. It would be better to raise an error much earlier in the process.
351 migrations = append(migrations, &Migration{
352 Version: s.version,
353 Next: -1,
354 Previous: -1,
355 Source: s.fullpath,
356 Registered: false,
357 })
358 }
359 }
360 } else {
361 // Some users may register Go migrations manually via AddNamedMigration_ functions but not
362 // provide the corresponding .go files in the filesystem. In this case, we include them
363 // wholesale in the migrations slice.
364 //
365 // This is a valid use case because users may want to build a custom binary that only embeds
366 // the SQL migration files and some other mechanism for registering Go migrations.
367 for _, migration := range registeredGoMigrations {
368 v, err := NumericComponent(migration.Source)
369 if err != nil {
370 return nil, fmt.Errorf("could not parse go migration file %s: %w", migration.Source, err)
371 }
372 if versionFilter(v, current, target) {
373 migrations = append(migrations, migration)
374 }
375 }
376 }
377 return migrations, nil
378}