1package goose
2
3import (
4 "errors"
5 "fmt"
6 "io/fs"
7 "path/filepath"
8 "sort"
9 "strings"
10)
11
12// fileSources represents a collection of migration files on the filesystem.
13type fileSources struct {
14 sqlSources []Source
15 goSources []Source
16}
17
18// collectFilesystemSources scans the file system for migration files that have a numeric prefix
19// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
20// be nil, in which case an empty fileSources is returned.
21//
22// If strict is true, then any error parsing the numeric component of the filename will result in an
23// error. The file is skipped otherwise.
24//
25// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
26// migration sources from the filesystem.
27func collectFilesystemSources(
28 fsys fs.FS,
29 strict bool,
30 excludePaths map[string]bool,
31 excludeVersions map[int64]bool,
32) (*fileSources, error) {
33 if fsys == nil {
34 return new(fileSources), nil
35 }
36 sources := new(fileSources)
37 versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath)
38 for _, pattern := range []string{
39 "*.sql",
40 "*.go",
41 } {
42 files, err := fs.Glob(fsys, pattern)
43 if err != nil {
44 return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
45 }
46 for _, fullpath := range files {
47 base := filepath.Base(fullpath)
48 if strings.HasSuffix(base, "_test.go") {
49 continue
50 }
51 if excludePaths[base] {
52 // TODO(mf): log this?
53 continue
54 }
55 // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
56 // that as the version. Otherwise, ignore it. This allows users to have arbitrary
57 // filenames, but still have versioned migrations within the same directory. For
58 // example, a user could have a helpers.go file which contains unexported helper
59 // functions for migrations.
60 version, err := NumericComponent(base)
61 if err != nil {
62 if strict {
63 return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
64 }
65 continue
66 }
67 if excludeVersions[version] {
68 // TODO: log this?
69 continue
70 }
71 // Ensure there are no duplicate versions.
72 if existing, ok := versionToBaseLookup[version]; ok {
73 return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
74 version,
75 existing,
76 base,
77 )
78 }
79 switch filepath.Ext(base) {
80 case ".sql":
81 sources.sqlSources = append(sources.sqlSources, Source{
82 Type: TypeSQL,
83 Path: fullpath,
84 Version: version,
85 })
86 case ".go":
87 sources.goSources = append(sources.goSources, Source{
88 Type: TypeGo,
89 Path: fullpath,
90 Version: version,
91 })
92 default:
93 // Should never happen since we already filtered out all other file types.
94 return nil, fmt.Errorf("invalid file extension: %q", base)
95 }
96 // Add the version to the lookup map.
97 versionToBaseLookup[version] = base
98 }
99 }
100 return sources, nil
101}
102
103func newSQLMigration(source Source) *Migration {
104 return &Migration{
105 Type: source.Type,
106 Version: source.Version,
107 Source: source.Path,
108 construct: true,
109 Next: -1, Previous: -1,
110 sql: sqlMigration{
111 Parsed: false, // SQL migrations are parsed lazily.
112 },
113 }
114}
115
116func merge(sources *fileSources, registered map[int64]*Migration) ([]*Migration, error) {
117 var migrations []*Migration
118 migrationLookup := make(map[int64]*Migration)
119 // Add all SQL migrations to the list of migrations.
120 for _, source := range sources.sqlSources {
121 m := newSQLMigration(source)
122 migrations = append(migrations, m)
123 migrationLookup[source.Version] = m
124 }
125 // If there are no Go files in the filesystem and no registered Go migrations, return early.
126 if len(sources.goSources) == 0 && len(registered) == 0 {
127 return migrations, nil
128 }
129 // Return an error if the given sources contain a versioned Go migration that has not been
130 // registered. This is a sanity check to ensure users didn't accidentally create a valid looking
131 // Go migration file on disk and forget to register it.
132 //
133 // This is almost always a user error.
134 var unregistered []string
135 for _, s := range sources.goSources {
136 m, ok := registered[s.Version]
137 if !ok {
138 unregistered = append(unregistered, s.Path)
139 } else {
140 // Populate the source path for registered Go migrations that have a corresponding file
141 // on disk.
142 m.Source = s.Path
143 }
144 }
145 if len(unregistered) > 0 {
146 return nil, unregisteredError(unregistered)
147 }
148 // Add all registered Go migrations to the list of migrations, checking for duplicate versions.
149 //
150 // Important, users can register Go migrations manually via goose.Add_ functions. These
151 // migrations may not have a corresponding file on disk. Which is fine! We include them
152 // wholesale as part of migrations. This allows users to build a custom binary that only embeds
153 // the SQL migration files.
154 for version, r := range registered {
155 // Ensure there are no duplicate versions.
156 if existing, ok := migrationLookup[version]; ok {
157 fullpath := r.Source
158 if fullpath == "" {
159 fullpath = "no source path"
160 }
161 return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
162 version,
163 existing.Source,
164 fullpath,
165 )
166 }
167 migrations = append(migrations, r)
168 migrationLookup[version] = r
169 }
170 // Sort migrations by version in ascending order.
171 sort.Slice(migrations, func(i, j int) bool {
172 return migrations[i].Version < migrations[j].Version
173 })
174 return migrations, nil
175}
176
// unregisteredError builds the error reported when Go migration files exist on disk but have no
// matching registered Go migration. The message states how many files were found, lists each path
// on its own tab-indented line, and ends with a hint linking to the goose Go-migration examples.
// The message text (including its trailing newline) is preserved so existing output is unchanged.
func unregisteredError(unregistered []string) error {
	const hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations"

	// Use the singular only for exactly one file ("0 files", "1 file", "2 files").
	noun := "files"
	if len(unregistered) == 1 {
		noun = "file"
	}

	var b strings.Builder
	fmt.Fprintf(&b, "error: detected %d unregistered Go %s:\n", len(unregistered), noun)
	for _, name := range unregistered {
		b.WriteByte('\t')
		b.WriteString(name)
		b.WriteByte('\n')
	}
	fmt.Fprintf(&b, "hint: go functions must be registered and built into a custom binary see:\n%s\n", hintURL)

	return errors.New(b.String())
}