provider_run.go

  1package goose
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"io/fs"
  9	"runtime/debug"
 10	"strings"
 11	"time"
 12
 13	"github.com/pressly/goose/v3/database"
 14	"github.com/pressly/goose/v3/internal/sqlparser"
 15	"github.com/sethvargo/go-retry"
 16	"go.uber.org/multierr"
 17)
 18
 19var (
 20	errMissingZeroVersion = errors.New("missing zero version migration")
 21)
 22
 23func (p *Provider) prepareMigration(fsys fs.FS, m *Migration, direction bool) error {
 24	switch m.Type {
 25	case TypeGo:
 26		if m.goUp.Mode == 0 {
 27			return errors.New("go up migration mode is not set")
 28		}
 29		if m.goDown.Mode == 0 {
 30			return errors.New("go down migration mode is not set")
 31		}
 32		var useTx bool
 33		if direction {
 34			useTx = m.goUp.Mode == TransactionEnabled
 35		} else {
 36			useTx = m.goDown.Mode == TransactionEnabled
 37		}
 38		// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB,
 39		// but are locking the database with *sql.Conn. If the caller sets max open connections to
 40		// 1, then this will deadlock because the Go migration will try to acquire a connection from
 41		// the pool, but the pool is exhausted because the lock is held.
 42		//
 43		// A potential solution is to expose a third Go register function *sql.Conn. Or continue to
 44		// use *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is
 45		// a bit of an edge case. For now, we guard against this scenario by checking the max open
 46		// connections and returning an error.
 47		if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
 48			if !useTx {
 49				return errors.New("potential deadlock detected: cannot run Go migration without a transaction when max open connections set to 1")
 50			}
 51		}
 52		return nil
 53	case TypeSQL:
 54		if m.sql.Parsed {
 55			return nil
 56		}
 57		parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source, false)
 58		if err != nil {
 59			return err
 60		}
 61		m.sql.Parsed = true
 62		m.sql.UseTx = parsed.UseTx
 63		m.sql.Up, m.sql.Down = parsed.Up, parsed.Down
 64		return nil
 65	}
 66	return fmt.Errorf("invalid migration type: %+v", m)
 67}
 68
 69// printf is a helper function that prints the given message if verbose is enabled. It also prepends
 70// the "goose: " prefix to the message.
 71func (p *Provider) printf(msg string, args ...interface{}) {
 72	if p.cfg.verbose {
 73		if !strings.HasPrefix(msg, "goose:") {
 74			msg = "goose: " + msg
 75		}
 76		p.cfg.logger.Printf(msg, args...)
 77	}
 78}
 79
 80// runMigrations runs migrations sequentially in the given direction. If the migrations list is
 81// empty, return nil without error.
 82func (p *Provider) runMigrations(
 83	ctx context.Context,
 84	conn *sql.Conn,
 85	migrations []*Migration,
 86	direction sqlparser.Direction,
 87	byOne bool,
 88) ([]*MigrationResult, error) {
 89	if len(migrations) == 0 {
 90		if !p.cfg.disableVersioning {
 91			// No need to print this message if versioning is disabled because there are no
 92			// migrations being tracked in the goose version table.
 93			maxVersion, err := p.getDBMaxVersion(ctx, conn)
 94			if err != nil {
 95				return nil, err
 96			}
 97			p.printf("no migrations to run, current version: %d", maxVersion)
 98		}
 99		return nil, nil
100	}
101	apply := migrations
102	if byOne {
103		apply = migrations[:1]
104	}
105
106	// SQL migrations are lazily parsed in both directions. This is done before attempting to run
107	// any migrations to catch errors early and prevent leaving the database in an incomplete state.
108
109	for _, m := range apply {
110		if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
111			return nil, fmt.Errorf("failed to prepare migration %s: %w", m.ref(), err)
112		}
113	}
114
115	// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
116	// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
117	// to run in a transaction.
118
119	// feat(mf): this is where we can (optionally) group multiple migrations to be run in a single
120	// transaction. The default is to apply each migration sequentially on its own. See the
121	// following issues for more details:
122	//  - https://github.com/pressly/goose/issues/485
123	//  - https://github.com/pressly/goose/issues/222
124	//
125	// Be careful, we can't use a single transaction for all migrations because some may be marked
126	// as not using a transaction.
127
128	var results []*MigrationResult
129	for _, m := range apply {
130		result := &MigrationResult{
131			Source: &Source{
132				Type:    m.Type,
133				Path:    m.Source,
134				Version: m.Version,
135			},
136			Direction: direction.String(),
137			Empty:     isEmpty(m, direction.ToBool()),
138		}
139		start := time.Now()
140		if err := p.runIndividually(ctx, conn, m, direction.ToBool()); err != nil {
141			// TODO(mf): we should also return the pending migrations here, the remaining items in
142			// the apply slice.
143			result.Error = err
144			result.Duration = time.Since(start)
145			return nil, &PartialError{
146				Applied: results,
147				Failed:  result,
148				Err:     err,
149			}
150		}
151		result.Duration = time.Since(start)
152		results = append(results, result)
153		p.printf("%s", result)
154	}
155	if !p.cfg.disableVersioning && !byOne {
156		maxVersion, err := p.getDBMaxVersion(ctx, conn)
157		if err != nil {
158			return nil, err
159		}
160		p.printf("successfully migrated database, current version: %d", maxVersion)
161	}
162	return results, nil
163}
164
165func (p *Provider) runIndividually(
166	ctx context.Context,
167	conn *sql.Conn,
168	m *Migration,
169	direction bool,
170) error {
171	useTx, err := useTx(m, direction)
172	if err != nil {
173		return err
174	}
175	if useTx {
176		return beginTx(ctx, conn, func(tx *sql.Tx) error {
177			if err := p.runMigration(ctx, tx, m, direction); err != nil {
178				return err
179			}
180			return p.maybeInsertOrDelete(ctx, tx, m.Version, direction)
181		})
182	}
183	switch m.Type {
184	case TypeGo:
185		// Note, we are using *sql.DB instead of *sql.Conn because it's the Go migration contract.
186		// This may be a deadlock scenario if max open connections is set to 1 AND a lock is
187		// acquired on the database. In this case, the migration will block forever unable to
188		// acquire a connection from the pool.
189		//
190		// For now, we guard against this scenario by checking the max open connections and
191		// returning an error in the prepareMigration function.
192		if err := p.runMigration(ctx, p.db, m, direction); err != nil {
193			return err
194		}
195		return p.maybeInsertOrDelete(ctx, p.db, m.Version, direction)
196	case TypeSQL:
197		if err := p.runMigration(ctx, conn, m, direction); err != nil {
198			return err
199		}
200		return p.maybeInsertOrDelete(ctx, conn, m.Version, direction)
201	}
202	return fmt.Errorf("failed to run individual migration: neither sql or go: %v", m)
203}
204
205func (p *Provider) maybeInsertOrDelete(
206	ctx context.Context,
207	db database.DBTxConn,
208	version int64,
209	direction bool,
210) error {
211	// If versioning is disabled, we don't need to insert or delete the migration version.
212	if p.cfg.disableVersioning {
213		return nil
214	}
215	if direction {
216		return p.store.Insert(ctx, db, database.InsertRequest{Version: version})
217	}
218	return p.store.Delete(ctx, db, version)
219}
220
221// beginTx begins a transaction and runs the given function. If the function returns an error, the
222// transaction is rolled back. Otherwise, the transaction is committed.
223func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (retErr error) {
224	tx, err := conn.BeginTx(ctx, nil)
225	if err != nil {
226		return err
227	}
228	defer func() {
229		if retErr != nil {
230			retErr = multierr.Append(retErr, tx.Rollback())
231		}
232	}()
233	if err := fn(tx); err != nil {
234		return err
235	}
236	return tx.Commit()
237}
238
239func (p *Provider) initialize(ctx context.Context, useSessionLocker bool) (*sql.Conn, func() error, error) {
240	p.mu.Lock()
241	conn, err := p.db.Conn(ctx)
242	if err != nil {
243		p.mu.Unlock()
244		return nil, nil, err
245	}
246	// cleanup is a function that cleans up the connection, and optionally, the session lock.
247	cleanup := func() error {
248		p.mu.Unlock()
249		return conn.Close()
250	}
251	if useSessionLocker && p.cfg.sessionLocker != nil && p.cfg.lockEnabled {
252		l := p.cfg.sessionLocker
253		if err := l.SessionLock(ctx, conn); err != nil {
254			return nil, nil, multierr.Append(err, cleanup())
255		}
256		// A lock was acquired, so we need to unlock the session when we're done. This is done by
257		// returning a cleanup function that unlocks the session and closes the connection.
258		cleanup = func() error {
259			p.mu.Unlock()
260			// Use a detached context to unlock the session. This is because the context passed to
261			// SessionLock may have been canceled, and we don't want to cancel the unlock.
262			return multierr.Append(l.SessionUnlock(context.WithoutCancel(ctx), conn), conn.Close())
263		}
264	}
265	// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
266	// need the version table because no versions are being tracked.
267	if !p.cfg.disableVersioning {
268		if err := p.ensureVersionTable(ctx, conn); err != nil {
269			return nil, nil, multierr.Append(err, cleanup())
270		}
271	}
272	return conn, cleanup, nil
273}
274
275func (p *Provider) ensureVersionTable(
276	ctx context.Context,
277	conn *sql.Conn,
278) (retErr error) {
279	// There are 2 optimizations here:
280	//  - 1. We create the version table once per Provider instance.
281	//  - 2. We retry the operation a few times in case the table is being created concurrently.
282	//
283	// Regarding item 2, certain goose operations, like HasPending, don't respect a SessionLocker.
284	// So, when goose is run for the first time in a multi-instance environment, it's possible that
285	// multiple instances will try to create the version table at the same time. This is why we
286	// retry this operation a few times. Best case, the table is created by one instance and all the
287	// other instances see that change immediately. Worst case, all instances try to create the
288	// table at the same time, but only one will succeed and the others will retry.
289	p.versionTableOnce.Do(func() {
290		retErr = p.tryEnsureVersionTable(ctx, conn)
291	})
292	return retErr
293}
294
295func (p *Provider) tryEnsureVersionTable(ctx context.Context, conn *sql.Conn) error {
296	b := retry.NewConstant(1 * time.Second)
297	b = retry.WithMaxRetries(3, b)
298	return retry.Do(ctx, b, func(ctx context.Context) error {
299		exists, err := p.store.TableExists(ctx, conn)
300		if err == nil && exists {
301			return nil
302		} else if err != nil && errors.Is(err, errors.ErrUnsupported) {
303			// Fallback strategy for checking table existence:
304			//
305			// When direct table existence checks aren't supported, we attempt to query the initial
306			// migration (version 0). This approach has two implications:
307			//
308			//  1. If the table exists, the query succeeds and confirms existence
309			//  2. If the table doesn't exist, the query fails and generates an error log
310			//
311			// Note: This check must occur outside any transaction, as a failed query would
312			// otherwise cause the entire transaction to roll back. The error logs generated by this
313			// approach are expected and can be safely ignored.
314			if res, err := p.store.GetMigration(ctx, conn, 0); err == nil && res != nil {
315				return nil
316			}
317			// Fallthrough to create the table.
318		} else if err != nil {
319			return fmt.Errorf("failed to check if version table exists: %w", err)
320		}
321
322		if err := beginTx(ctx, conn, func(tx *sql.Tx) error {
323			if err := p.store.CreateVersionTable(ctx, tx); err != nil {
324				return err
325			}
326			return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
327		}); err != nil {
328			// Mark the error as retryable so we can try again. It's possible that another instance
329			// is creating the table at the same time and the checks above will succeed on the next
330			// iteration.
331			return retry.RetryableError(fmt.Errorf("failed to create version table: %w", err))
332		}
333		return nil
334	})
335}
336
337// getMigration returns the migration for the given version. If no migration is found, then
338// ErrVersionNotFound is returned.
339func (p *Provider) getMigration(version int64) (*Migration, error) {
340	for _, m := range p.migrations {
341		if m.Version == version {
342			return m, nil
343		}
344	}
345	return nil, ErrVersionNotFound
346}
347
348// useTx is a helper function that returns true if the migration should be run in a transaction. It
349// must only be called after the migration has been parsed and initialized.
350func useTx(m *Migration, direction bool) (bool, error) {
351	switch m.Type {
352	case TypeGo:
353		if m.goUp.Mode == 0 || m.goDown.Mode == 0 {
354			return false, fmt.Errorf("go migrations must have a mode set")
355		}
356		if direction {
357			return m.goUp.Mode == TransactionEnabled, nil
358		}
359		return m.goDown.Mode == TransactionEnabled, nil
360	case TypeSQL:
361		if !m.sql.Parsed {
362			return false, fmt.Errorf("sql migrations must be parsed")
363		}
364		return m.sql.UseTx, nil
365	}
366	return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type)
367}
368
369// isEmpty is a helper function that returns true if the migration has no functions or no statements
370// to execute. It must only be called after the migration has been parsed and initialized.
371func isEmpty(m *Migration, direction bool) bool {
372	switch m.Type {
373	case TypeGo:
374		if direction {
375			return m.goUp.RunTx == nil && m.goUp.RunDB == nil
376		}
377		return m.goDown.RunTx == nil && m.goDown.RunDB == nil
378	case TypeSQL:
379		if direction {
380			return len(m.sql.Up) == 0
381		}
382		return len(m.sql.Down) == 0
383	}
384	return true
385}
386
387// runMigration is a helper function that runs the migration in the given direction. It must only be
388// called after the migration has been parsed and initialized.
389func (p *Provider) runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
390	switch m.Type {
391	case TypeGo:
392		return p.runGo(ctx, db, m, direction)
393	case TypeSQL:
394		return p.runSQL(ctx, db, m, direction)
395	}
396	return fmt.Errorf("invalid migration type: %q", m.Type)
397}
398
399// runGo is a helper function that runs the given Go functions in the given direction. It must only
400// be called after the migration has been initialized.
401func (p *Provider) runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) (retErr error) {
402	defer func() {
403		if r := recover(); r != nil {
404			retErr = fmt.Errorf("panic: %v\n%s", r, debug.Stack())
405		}
406	}()
407
408	switch db := db.(type) {
409	case *sql.Conn:
410		return fmt.Errorf("go migrations are not supported with *sql.Conn")
411	case *sql.DB:
412		if direction && m.goUp.RunDB != nil {
413			return m.goUp.RunDB(ctx, db)
414		}
415		if !direction && m.goDown.RunDB != nil {
416			return m.goDown.RunDB(ctx, db)
417		}
418		return nil
419	case *sql.Tx:
420		if direction && m.goUp.RunTx != nil {
421			return m.goUp.RunTx(ctx, db)
422		}
423		if !direction && m.goDown.RunTx != nil {
424			return m.goDown.RunTx(ctx, db)
425		}
426		return nil
427	}
428	return fmt.Errorf("invalid database connection type: %T", db)
429}
430
431// runSQL is a helper function that runs the given SQL statements in the given direction. It must
432// only be called after the migration has been parsed.
433func (p *Provider) runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
434
435	if !m.sql.Parsed {
436		return fmt.Errorf("sql migrations must be parsed")
437	}
438	var statements []string
439	if direction {
440		statements = m.sql.Up
441	} else {
442		statements = m.sql.Down
443	}
444	for _, stmt := range statements {
445		if p.cfg.verbose {
446			p.cfg.logger.Printf("Excuting statement: %s", stmt)
447		}
448		if _, err := db.ExecContext(ctx, stmt); err != nil {
449			return err
450		}
451	}
452	return nil
453}