migrate.go

  1package migrate
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8
  9	"github.com/charmbracelet/log"
 10	"github.com/charmbracelet/soft-serve/pkg/db"
 11)
 12
 13// MigrateFunc is a function that executes a migration.
 14type MigrateFunc func(ctx context.Context, h db.Handler) error // nolint:revive
 15
 16// Migration is a struct that contains the name of the migration and the
 17// function to execute it.
 18type Migration struct {
 19	Version     int64
 20	Name        string
 21	PreMigrate  MigrateFunc
 22	Migrate     MigrateFunc
 23	PostMigrate MigrateFunc
 24	Rollback    MigrateFunc
 25}
 26
 27// Migrations is a database model to store migrations.
 28type Migrations struct {
 29	ID      int64  `db:"id"`
 30	Name    string `db:"name"`
 31	Version int64  `db:"version"`
 32}
 33
 34func (Migrations) schema(driverName string) string {
 35	switch driverName {
 36	case "sqlite3", "sqlite":
 37		return `CREATE TABLE IF NOT EXISTS migrations (
 38				id INTEGER PRIMARY KEY AUTOINCREMENT,
 39				name TEXT NOT NULL,
 40				version INTEGER NOT NULL UNIQUE
 41			);
 42		`
 43	case "postgres":
 44		return `CREATE TABLE IF NOT EXISTS migrations (
 45			id SERIAL PRIMARY KEY,
 46			name TEXT NOT NULL,
 47			version INTEGER NOT NULL UNIQUE
 48		);
 49	`
 50	case "mysql":
 51		return `CREATE TABLE IF NOT EXISTS migrations (
 52			id INT NOT NULL AUTO_INCREMENT,
 53			name TEXT NOT NULL,
 54			version INT NOT NULL,
 55			UNIQUE (version),
 56			PRIMARY KEY (id)
 57		);
 58	`
 59	default:
 60		panic("unknown driver")
 61	}
 62}
 63
 64// Migrate runs the migrations.
 65func Migrate(ctx context.Context, dbx *db.DB) error {
 66	logger := log.FromContext(ctx).WithPrefix("migrate")
 67
 68	if !hasTable(dbx, "migrations") {
 69		if _, err := dbx.Exec(Migrations{}.schema(dbx.DriverName())); err != nil {
 70			return err
 71		}
 72	}
 73
 74	var migrs Migrations
 75	if err := dbx.Get(&migrs, dbx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
 76		if !errors.Is(err, sql.ErrNoRows) {
 77			return err
 78		}
 79	}
 80
 81	for _, m := range migrations {
 82		if m.Version <= migrs.Version {
 83			continue
 84		}
 85
 86		logger.Infof("running migration %d. %s", m.Version, m.Name)
 87		if m.PreMigrate != nil {
 88			if err := m.PreMigrate(ctx, dbx); err != nil {
 89				return err
 90			}
 91		}
 92
 93		if err := dbx.TransactionContext(ctx, func(tx *db.Tx) error {
 94			return m.Migrate(ctx, tx)
 95		}); err != nil {
 96			return err
 97		}
 98
 99		if m.PostMigrate != nil {
100			if err := m.PostMigrate(ctx, dbx); err != nil {
101				return err
102			}
103		}
104
105		if _, err := dbx.Exec(dbx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
106			return err
107		}
108	}
109
110	return nil
111}
112
113// Rollback rolls back a migration.
114func Rollback(ctx context.Context, dbx *db.DB) error {
115	logger := log.FromContext(ctx).WithPrefix("migrate")
116	return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
117		var migrs Migrations
118		if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
119			if !errors.Is(err, sql.ErrNoRows) {
120				return fmt.Errorf("there are no migrations to rollback: %w", err)
121			}
122		}
123
124		if migrs.Version == 0 || len(migrations) < int(migrs.Version) {
125			return fmt.Errorf("there are no migrations to rollback")
126		}
127
128		m := migrations[migrs.Version-1]
129		logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
130		if err := m.Rollback(ctx, tx); err != nil {
131			return err
132		}
133
134		if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
135			return err
136		}
137
138		return nil
139	})
140}
141
142func hasTable(tx db.Handler, tableName string) bool {
143	var query string
144	switch tx.DriverName() {
145	case "sqlite3", "sqlite":
146		query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
147	case "postgres":
148		fallthrough
149	case "mysql":
150		query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
151	}
152
153	query = tx.Rebind(query)
154	var name string
155	err := tx.Get(&name, query, tableName)
156	return err == nil
157}