migrate.go

  1package migrate
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8
  9	log "github.com/charmbracelet/log/v2"
 10	"github.com/charmbracelet/soft-serve/pkg/db"
 11)
 12
 13// MigrateFunc is a function that executes a migration.
 14type MigrateFunc func(ctx context.Context, tx *db.Tx) error //nolint:revive
 15
 16// Migration is a struct that contains the name of the migration and the
 17// function to execute it.
 18type Migration struct {
 19	Version  int64
 20	Name     string
 21	Migrate  MigrateFunc
 22	Rollback MigrateFunc
 23}
 24
 25// Migrations is a database model to store migrations.
 26type Migrations struct {
 27	ID      int64  `db:"id"`
 28	Name    string `db:"name"`
 29	Version int64  `db:"version"`
 30}
 31
 32func (Migrations) schema(driverName string) string {
 33	switch driverName {
 34	case driverSQLite3, driverSQLite:
 35		return `CREATE TABLE IF NOT EXISTS migrations (
 36				id INTEGER PRIMARY KEY AUTOINCREMENT,
 37				name TEXT NOT NULL,
 38				version INTEGER NOT NULL UNIQUE
 39			);
 40		`
 41	case "postgres":
 42		return `CREATE TABLE IF NOT EXISTS migrations (
 43			id SERIAL PRIMARY KEY,
 44			name TEXT NOT NULL,
 45			version INTEGER NOT NULL UNIQUE
 46		);
 47	`
 48	case "mysql":
 49		return `CREATE TABLE IF NOT EXISTS migrations (
 50			id INT NOT NULL AUTO_INCREMENT,
 51			name TEXT NOT NULL,
 52			version INT NOT NULL,
 53			UNIQUE (version),
 54			PRIMARY KEY (id)
 55		);
 56	`
 57	default:
 58		panic("unknown driver")
 59	}
 60}
 61
 62// Migrate runs the migrations.
 63func Migrate(ctx context.Context, dbx *db.DB) error {
 64	logger := log.FromContext(ctx).WithPrefix("migrate")
 65	return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
 66		if !hasTable(tx, "migrations") {
 67			if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {
 68				return err
 69			}
 70		}
 71
 72		var migrs Migrations
 73		if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
 74			if !errors.Is(err, sql.ErrNoRows) {
 75				return err
 76			}
 77		}
 78
 79		for _, m := range migrations {
 80			if m.Version <= migrs.Version {
 81				continue
 82			}
 83
 84			logger.Infof("running migration %d. %s", m.Version, m.Name)
 85			if err := m.Migrate(ctx, tx); err != nil {
 86				return err
 87			}
 88
 89			if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
 90				return err
 91			}
 92		}
 93
 94		return nil
 95	})
 96}
 97
 98// Rollback rolls back a migration.
 99func Rollback(ctx context.Context, dbx *db.DB) error {
100	logger := log.FromContext(ctx).WithPrefix("migrate")
101	return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
102		var migrs Migrations
103		if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
104			if !errors.Is(err, sql.ErrNoRows) {
105				return fmt.Errorf("there are no migrations to rollback: %w", err)
106			}
107		}
108
109		if migrs.Version == 0 || len(migrations) < int(migrs.Version) {
110			return fmt.Errorf("there are no migrations to rollback")
111		}
112
113		m := migrations[migrs.Version-1]
114		logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
115		if err := m.Rollback(ctx, tx); err != nil {
116			return err
117		}
118
119		if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
120			return err
121		}
122
123		return nil
124	})
125}
126
127func hasTable(tx *db.Tx, tableName string) bool {
128	var query string
129	switch tx.DriverName() {
130	case driverSQLite3, driverSQLite:
131		query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
132	case driverPostgres:
133		fallthrough
134	case "mysql":
135		query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
136	}
137
138	query = tx.Rebind(query)
139	var name string
140	err := tx.Get(&name, query, tableName)
141	return err == nil
142}