migrate.go

  1package migrate
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8
  9	"github.com/charmbracelet/log/v2"
 10	"github.com/charmbracelet/soft-serve/pkg/db"
 11)
 12
 13const (
 14	postgresDriver = "postgres"
 15	sqliteDriver   = "sqlite"
 16	sqlite3Driver  = "sqlite3"
 17)
 18
// MigrateFunc is a function that executes a migration.
//
// It is the type of both Migration.Migrate (apply) and Migration.Rollback
// (revert). It receives the transaction that Migrate/Rollback opened. A
// non-nil error is returned from that transaction callback unchanged.
// NOTE(review): presumably db.DB.TransactionContext rolls the transaction
// back on error — confirm in pkg/db.
type MigrateFunc func(ctx context.Context, tx *db.Tx) error //nolint:revive
 21
// Migration is a struct that contains the name of the migration and the
// function to execute it.
//
// Migrate runs every entry of the package-level migrations list whose Version
// is greater than the highest version recorded in the migrations table.
// Rollback reverts the most recently recorded one.
type Migration struct {
	// Version is the value stored in the migrations table's version column.
	Version  int64
	// Name is a human-readable label; it is logged and stored with Version.
	Name     string
	// Migrate applies this migration.
	Migrate  MigrateFunc
	// Rollback reverts this migration.
	Rollback MigrateFunc
}
 30
// Migrations is a database model to store migrations.
//
// Each row records one applied migration. The db struct tags map the fields
// onto the id, name and version columns of the migrations table created by
// schema.
type Migrations struct {
	ID      int64  `db:"id"`
	Name    string `db:"name"`
	Version int64  `db:"version"`
}
 37
 38func (Migrations) schema(driverName string) string {
 39	switch driverName {
 40	case sqlite3Driver, sqliteDriver:
 41		return `CREATE TABLE IF NOT EXISTS migrations (
 42				id INTEGER PRIMARY KEY AUTOINCREMENT,
 43				name TEXT NOT NULL,
 44				version INTEGER NOT NULL UNIQUE
 45			);
 46		`
 47	case postgresDriver:
 48		return `CREATE TABLE IF NOT EXISTS migrations (
 49			id SERIAL PRIMARY KEY,
 50			name TEXT NOT NULL,
 51			version INTEGER NOT NULL UNIQUE
 52		);
 53	`
 54	case "mysql":
 55		return `CREATE TABLE IF NOT EXISTS migrations (
 56			id INT NOT NULL AUTO_INCREMENT,
 57			name TEXT NOT NULL,
 58			version INT NOT NULL,
 59			UNIQUE (version),
 60			PRIMARY KEY (id)
 61		);
 62	`
 63	default:
 64		panic("unknown driver")
 65	}
 66}
 67
 68// Migrate runs the migrations.
 69func Migrate(ctx context.Context, dbx *db.DB) error {
 70	logger := log.FromContext(ctx).WithPrefix("migrate")
 71	return dbx.TransactionContext(ctx, func(tx *db.Tx) error { //nolint:wrapcheck
 72		if !hasTable(tx, "migrations") {
 73			if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {
 74				return err //nolint:wrapcheck
 75			}
 76		}
 77
 78		var migrs Migrations
 79		if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
 80			if !errors.Is(err, sql.ErrNoRows) {
 81				return err //nolint:wrapcheck
 82			}
 83		}
 84
 85		for _, m := range migrations {
 86			if m.Version <= migrs.Version {
 87				continue
 88			}
 89
 90			logger.Infof("running migration %d. %s", m.Version, m.Name)
 91			if err := m.Migrate(ctx, tx); err != nil {
 92				return err
 93			}
 94
 95			if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
 96				return err //nolint:wrapcheck
 97			}
 98		}
 99
100		return nil
101	})
102}
103
104// Rollback rolls back a migration.
105func Rollback(ctx context.Context, dbx *db.DB) error {
106	logger := log.FromContext(ctx).WithPrefix("migrate")
107	return dbx.TransactionContext(ctx, func(tx *db.Tx) error { //nolint:wrapcheck
108		var migrs Migrations
109		if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
110			if !errors.Is(err, sql.ErrNoRows) {
111				return fmt.Errorf("there are no migrations to rollback: %w", err)
112			}
113		}
114
115		if migrs.Version == 0 || len(migrations) < int(migrs.Version) {
116			return fmt.Errorf("there are no migrations to rollback")
117		}
118
119		m := migrations[migrs.Version-1]
120		logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
121		if err := m.Rollback(ctx, tx); err != nil {
122			return err
123		}
124
125		if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
126			return err //nolint:wrapcheck
127		}
128
129		return nil
130	})
131}
132
133func hasTable(tx *db.Tx, tableName string) bool {
134	var query string
135	switch tx.DriverName() {
136	case sqlite3Driver, sqliteDriver:
137		query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
138	case postgresDriver:
139		fallthrough
140	case "mysql":
141		query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
142	}
143
144	query = tx.Rebind(query)
145	var name string
146	err := tx.Get(&name, query, tableName)
147	return err == nil
148}