1package migrate
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 log "github.com/charmbracelet/log/v2"
10 "github.com/charmbracelet/soft-serve/pkg/db"
11)
12
13// MigrateFunc is a function that executes a migration.
14type MigrateFunc func(ctx context.Context, tx *db.Tx) error //nolint:revive
15
// Migration describes a single schema migration: its Version, a descriptive
// Name, and the functions that apply (Migrate) and revert (Rollback) it.
type Migration struct {
	// Version identifies the migration. Migrate runs every migration whose
	// Version is greater than the highest version recorded in the database,
	// and that value is what gets stored in the migrations table.
	Version int64
	// Name is logged when the migration runs and is stored alongside Version.
	Name string
	// Migrate applies the migration.
	Migrate MigrateFunc
	// Rollback reverts the migration; it is called by Rollback.
	Rollback MigrateFunc
}
24
// Migrations is the database model for one row of the migrations table.
// Migrate inserts a row (name, version) after each migration it applies, and
// Rollback deletes the row of the migration it reverts; id is assigned by the
// database.
type Migrations struct {
	ID      int64  `db:"id"`
	Name    string `db:"name"`
	Version int64  `db:"version"`
}
31
32func (Migrations) schema(driverName string) string {
33 switch driverName {
34 case driverSQLite3, driverSQLite:
35 return `CREATE TABLE IF NOT EXISTS migrations (
36 id INTEGER PRIMARY KEY AUTOINCREMENT,
37 name TEXT NOT NULL,
38 version INTEGER NOT NULL UNIQUE
39 );
40 `
41 case "postgres":
42 return `CREATE TABLE IF NOT EXISTS migrations (
43 id SERIAL PRIMARY KEY,
44 name TEXT NOT NULL,
45 version INTEGER NOT NULL UNIQUE
46 );
47 `
48 case "mysql":
49 return `CREATE TABLE IF NOT EXISTS migrations (
50 id INT NOT NULL AUTO_INCREMENT,
51 name TEXT NOT NULL,
52 version INT NOT NULL,
53 UNIQUE (version),
54 PRIMARY KEY (id)
55 );
56 `
57 default:
58 panic("unknown driver")
59 }
60}
61
62// Migrate runs the migrations.
63func Migrate(ctx context.Context, dbx *db.DB) error {
64 logger := log.FromContext(ctx).WithPrefix("migrate")
65 return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
66 if !hasTable(tx, "migrations") {
67 if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {
68 return err
69 }
70 }
71
72 var migrs Migrations
73 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
74 if !errors.Is(err, sql.ErrNoRows) {
75 return err
76 }
77 }
78
79 for _, m := range migrations {
80 if m.Version <= migrs.Version {
81 continue
82 }
83
84 logger.Infof("running migration %d. %s", m.Version, m.Name)
85 if err := m.Migrate(ctx, tx); err != nil {
86 return err
87 }
88
89 if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
90 return err
91 }
92 }
93
94 return nil
95 })
96}
97
98// Rollback rolls back a migration.
99func Rollback(ctx context.Context, dbx *db.DB) error {
100 logger := log.FromContext(ctx).WithPrefix("migrate")
101 return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
102 var migrs Migrations
103 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
104 if !errors.Is(err, sql.ErrNoRows) {
105 return fmt.Errorf("there are no migrations to rollback: %w", err)
106 }
107 }
108
109 if migrs.Version == 0 || len(migrations) < int(migrs.Version) {
110 return fmt.Errorf("there are no migrations to rollback")
111 }
112
113 m := migrations[migrs.Version-1]
114 logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
115 if err := m.Rollback(ctx, tx); err != nil {
116 return err
117 }
118
119 if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
120 return err
121 }
122
123 return nil
124 })
125}
126
127func hasTable(tx *db.Tx, tableName string) bool {
128 var query string
129 switch tx.DriverName() {
130 case driverSQLite3, driverSQLite:
131 query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
132 case driverPostgres:
133 fallthrough
134 case "mysql":
135 query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
136 }
137
138 query = tx.Rebind(query)
139 var name string
140 err := tx.Get(&name, query, tableName)
141 return err == nil
142}