migration.go

  1package goose
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"path/filepath"
  9	"strconv"
 10	"strings"
 11	"time"
 12
 13	"github.com/pressly/goose/v3/internal/sqlparser"
 14)
 15
 16// NewGoMigration creates a new Go migration.
 17//
 18// Both up and down functions may be nil, in which case the migration will be recorded in the
 19// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
 20// a version without running any functions. See [GoFunc] for more details.
 21func NewGoMigration(version int64, up, down *GoFunc) *Migration {
 22	m := &Migration{
 23		Type:       TypeGo,
 24		Registered: true,
 25		Version:    version,
 26		Next:       -1, Previous: -1,
 27		goUp:      &GoFunc{Mode: TransactionEnabled},
 28		goDown:    &GoFunc{Mode: TransactionEnabled},
 29		construct: true,
 30	}
 31	updateMode := func(f *GoFunc) *GoFunc {
 32		// infer mode from function
 33		if f.Mode == 0 {
 34			if f.RunTx != nil && f.RunDB == nil {
 35				f.Mode = TransactionEnabled
 36			}
 37			if f.RunTx == nil && f.RunDB != nil {
 38				f.Mode = TransactionDisabled
 39			}
 40			// Always default to TransactionEnabled if both functions are nil. This is the most
 41			// common use case.
 42			if f.RunDB == nil && f.RunTx == nil {
 43				f.Mode = TransactionEnabled
 44			}
 45		}
 46		return f
 47	}
 48	// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
 49	// we will remove these fields in favor of [GoFunc].
 50	//
 51	// Note, this function does not do any validation. Validation is lazily done when the migration
 52	// is registered.
 53	if up != nil {
 54		m.goUp = updateMode(up)
 55
 56		if up.RunDB != nil {
 57			m.UpFnNoTxContext = up.RunDB          // func(context.Context, *sql.DB) error
 58			m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
 59		}
 60		if up.RunTx != nil {
 61			m.UseTx = true
 62			m.UpFnContext = up.RunTx          // func(context.Context, *sql.Tx) error
 63			m.UpFn = withoutContext(up.RunTx) // func(*sql.Tx) error
 64		}
 65	}
 66	if down != nil {
 67		m.goDown = updateMode(down)
 68
 69		if down.RunDB != nil {
 70			m.DownFnNoTxContext = down.RunDB          // func(context.Context, *sql.DB) error
 71			m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
 72		}
 73		if down.RunTx != nil {
 74			m.UseTx = true
 75			m.DownFnContext = down.RunTx          // func(context.Context, *sql.Tx) error
 76			m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
 77		}
 78	}
 79	return m
 80}
 81
 82// Migration struct represents either a SQL or Go migration.
 83//
 84// Avoid constructing migrations manually, use [NewGoMigration] function.
 85type Migration struct {
 86	Type    MigrationType
 87	Version int64
 88	// Source is the path to the .sql script or .go file. It may be empty for Go migrations that
 89	// have been registered globally and don't have a source file.
 90	Source string
 91
 92	UpFnContext, DownFnContext         GoMigrationContext
 93	UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
 94
 95	// These fields will be removed in a future major version. They are here for backwards
 96	// compatibility and are an implementation detail.
 97	Registered bool
 98	UseTx      bool
 99	Next       int64 // next version, or -1 if none
100	Previous   int64 // previous version, -1 if none
101
102	// We still save the non-context versions in the struct in case someone is using them. Goose
103	// does not use these internally anymore in favor of the context-aware versions. These fields
104	// will be removed in a future major version.
105
106	UpFn       GoMigration     // Deprecated: use UpFnContext instead.
107	DownFn     GoMigration     // Deprecated: use DownFnContext instead.
108	UpFnNoTx   GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead.
109	DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
110
111	noVersioning bool
112
113	// These fields are used internally by goose and users are not expected to set them. Instead,
114	// use [NewGoMigration] to create a new go migration.
115	construct    bool
116	goUp, goDown *GoFunc
117
118	sql sqlMigration
119}
120
// sqlMigration holds the parsed contents of a SQL migration file.
type sqlMigration struct {
	// Parsed reports whether the file has been parsed yet. Parsing is deferred as an
	// optimization: migrations are usually incremental and users typically run only the most
	// recent ones, so older files often never need to be parsed at all.
	Parsed bool

	// The fields below are meaningful only once Parsed is true.

	// UseTx records whether the migration's statements should run in a transaction.
	UseTx bool
	// Up holds the statements for the up direction.
	Up []string
	// Down holds the statements for the down direction.
	Down []string
}
133
134// GoFunc represents a Go migration function.
135type GoFunc struct {
136	// Exactly one of these must be set, or both must be nil.
137	RunTx func(ctx context.Context, tx *sql.Tx) error
138	// -- OR --
139	RunDB func(ctx context.Context, db *sql.DB) error
140
141	// Mode is the transaction mode for the migration. When one of the run functions is set, the
142	// mode will be inferred from the function and the field is ignored. Users do not need to set
143	// this field when supplying a run function.
144	//
145	// If both run functions are nil, the mode defaults to TransactionEnabled. The use case for nil
146	// functions is to record a version in the version table without invoking a Go migration
147	// function.
148	//
149	// The only time this field is required is if BOTH run functions are nil AND you want to
150	// override the default transaction mode.
151	Mode TransactionMode
152}
153
// TransactionMode selects whether a migration runs inside a database transaction.
type TransactionMode int

// Supported transaction modes. The zero value is deliberately not a valid mode, so an unset
// Mode can be told apart from an explicit choice.
const (
	// TransactionEnabled runs the migration inside a transaction.
	TransactionEnabled TransactionMode = iota + 1
	// TransactionDisabled runs the migration outside a transaction.
	TransactionDisabled
)

// String returns the snake_case name of the mode. Unrecognized values are rendered as
// "unknown transaction mode (N)".
func (m TransactionMode) String() string {
	var name string
	switch m {
	case TransactionDisabled:
		name = "transaction_disabled"
	case TransactionEnabled:
		name = "transaction_enabled"
	default:
		name = fmt.Sprintf("unknown transaction mode (%d)", m)
	}
	return name
}
172
// MigrationRecord pairs a migration version with a timestamp and an applied flag.
//
// Deprecated: unused and will be removed in a future major version.
type MigrationRecord struct {
	VersionID int64     // migration version
	TStamp    time.Time // timestamp associated with the record
	IsApplied bool      // result of up() or down(); presumably true for up()
}
181
182func (m *Migration) String() string {
183	return fmt.Sprint(m.Source)
184}
185
186// Up runs an up migration.
187func (m *Migration) Up(db *sql.DB) error {
188	ctx := context.Background()
189	return m.UpContext(ctx, db)
190}
191
192// UpContext runs an up migration.
193func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
194	if err := m.run(ctx, db, true); err != nil {
195		return err
196	}
197	return nil
198}
199
200// Down runs a down migration.
201func (m *Migration) Down(db *sql.DB) error {
202	ctx := context.Background()
203	return m.DownContext(ctx, db)
204}
205
206// DownContext runs a down migration.
207func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
208	if err := m.run(ctx, db, false); err != nil {
209		return err
210	}
211	return nil
212}
213
214func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
215	switch filepath.Ext(m.Source) {
216	case ".sql":
217		f, err := baseFS.Open(m.Source)
218		if err != nil {
219			return fmt.Errorf("ERROR %v: failed to open SQL migration file: %w", filepath.Base(m.Source), err)
220		}
221		defer f.Close()
222
223		statements, useTx, err := sqlparser.ParseSQLMigration(f, sqlparser.FromBool(direction), verbose)
224		if err != nil {
225			return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err)
226		}
227
228		start := time.Now()
229		if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
230			return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err)
231		}
232		finish := truncateDuration(time.Since(start))
233
234		if len(statements) > 0 {
235			log.Printf("OK   %s (%s)", filepath.Base(m.Source), finish)
236		} else {
237			log.Printf("EMPTY %s (%s)", filepath.Base(m.Source), finish)
238		}
239
240	case ".go":
241		if !m.Registered {
242			return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
243		}
244		start := time.Now()
245		var empty bool
246		if m.UseTx {
247			// Run go-based migration inside a tx.
248			fn := m.DownFnContext
249			if direction {
250				fn = m.UpFnContext
251			}
252			empty = (fn == nil)
253			if err := runGoMigration(
254				ctx,
255				db,
256				fn,
257				m.Version,
258				direction,
259				!m.noVersioning,
260			); err != nil {
261				return fmt.Errorf("ERROR go migration: %q: %w", filepath.Base(m.Source), err)
262			}
263		} else {
264			// Run go-based migration outside a tx.
265			fn := m.DownFnNoTxContext
266			if direction {
267				fn = m.UpFnNoTxContext
268			}
269			empty = (fn == nil)
270			if err := runGoMigrationNoTx(
271				ctx,
272				db,
273				fn,
274				m.Version,
275				direction,
276				!m.noVersioning,
277			); err != nil {
278				return fmt.Errorf("ERROR go migration no tx: %q: %w", filepath.Base(m.Source), err)
279			}
280		}
281		finish := truncateDuration(time.Since(start))
282		if !empty {
283			log.Printf("OK   %s (%s)", filepath.Base(m.Source), finish)
284		} else {
285			log.Printf("EMPTY %s (%s)", filepath.Base(m.Source), finish)
286		}
287	}
288	return nil
289}
290
291func runGoMigrationNoTx(
292	ctx context.Context,
293	db *sql.DB,
294	fn GoMigrationNoTxContext,
295	version int64,
296	direction bool,
297	recordVersion bool,
298) error {
299	if fn != nil {
300		// Run go migration function.
301		if err := fn(ctx, db); err != nil {
302			return fmt.Errorf("failed to run go migration: %w", err)
303		}
304	}
305	if recordVersion {
306		return insertOrDeleteVersionNoTx(ctx, db, version, direction)
307	}
308	return nil
309}
310
311func runGoMigration(
312	ctx context.Context,
313	db *sql.DB,
314	fn GoMigrationContext,
315	version int64,
316	direction bool,
317	recordVersion bool,
318) error {
319	if fn == nil && !recordVersion {
320		return nil
321	}
322	tx, err := db.BeginTx(ctx, nil)
323	if err != nil {
324		return fmt.Errorf("failed to begin transaction: %w", err)
325	}
326	if fn != nil {
327		// Run go migration function.
328		if err := fn(ctx, tx); err != nil {
329			_ = tx.Rollback()
330			return fmt.Errorf("failed to run go migration: %w", err)
331		}
332	}
333	if recordVersion {
334		if err := insertOrDeleteVersion(ctx, tx, version, direction); err != nil {
335			_ = tx.Rollback()
336			return fmt.Errorf("failed to update version: %w", err)
337		}
338	}
339	if err := tx.Commit(); err != nil {
340		return fmt.Errorf("failed to commit transaction: %w", err)
341	}
342	return nil
343}
344
345func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
346	if direction {
347		return store.InsertVersion(ctx, tx, TableName(), version)
348	}
349	return store.DeleteVersion(ctx, tx, TableName(), version)
350}
351
352func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
353	if direction {
354		return store.InsertVersionNoTx(ctx, db, TableName(), version)
355	}
356	return store.DeleteVersionNoTx(ctx, db, TableName(), version)
357}
358
// NumericComponent extracts the version number from a migration file name.
//
// Directory components of filename are ignored. The base name must look like
// NNN_description.sql or NNN_description.go. The text before the first underscore is parsed
// as a base-10 int64 version, which must be at least 1.
func NumericComponent(filename string) (int64, error) {
	name := filepath.Base(filename)

	switch filepath.Ext(name) {
	case ".go", ".sql":
		// Supported migration type.
	default:
		return 0, errors.New("migration file does not have .sql or .go file extension")
	}

	prefix, _, found := strings.Cut(name, "_")
	if !found {
		return 0, errors.New("no filename separator '_' found")
	}

	version, err := strconv.ParseInt(prefix, 10, 64)
	switch {
	case err != nil:
		return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", name, err)
	case version <= 0:
		return 0, errors.New("migration version must be greater than zero")
	}
	return version, nil
}
381
// truncateDuration rounds d to two decimal places of its largest unit, for compact log
// output. Durations above one second round to 10ms, above one millisecond to 10µs, and
// above one microsecond to 10ns. Anything else (including zero and negative values) is
// returned unchanged.
func truncateDuration(d time.Duration) time.Duration {
	switch {
	case d > time.Second:
		return d.Round(10 * time.Millisecond)
	case d > time.Millisecond:
		return d.Round(10 * time.Microsecond)
	case d > time.Microsecond:
		return d.Round(10 * time.Nanosecond)
	default:
		return d
	}
}
394
395// ref returns a string that identifies the migration. This is used for logging and error messages.
396func (m *Migration) ref() string {
397	return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version)
398}