1package migrate
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 "github.com/charmbracelet/log/v2"
10 "github.com/charmbracelet/soft-serve/pkg/db"
11)
12
// Database driver names that select dialect-specific SQL in this package
// (see the schema and table-existence switches below).
const (
	// sqliteDriver and sqlite3Driver are both treated as the SQLite dialect.
	sqliteDriver  = "sqlite"
	sqlite3Driver = "sqlite3"

	// postgresDriver selects the PostgreSQL dialect.
	postgresDriver = "postgres"
)
18
19// MigrateFunc is a function that executes a migration.
20type MigrateFunc func(ctx context.Context, tx *db.Tx) error //nolint:revive
21
22// Migration is a struct that contains the name of the migration and the
23// function to execute it.
24type Migration struct {
25 Version int64
26 Name string
27 Migrate MigrateFunc
28 Rollback MigrateFunc
29}
30
// Migrations is the row model of the migrations bookkeeping table: each row
// records one applied migration. The db tags name the table's columns.
type Migrations struct {
	// ID is the auto-generated primary key.
	ID int64 `db:"id"`

	// Name is the human-readable name of the applied migration.
	Name string `db:"name"`

	// Version is the applied migration's version; the column is UNIQUE.
	Version int64 `db:"version"`
}
37
38func (Migrations) schema(driverName string) string {
39 switch driverName {
40 case sqlite3Driver, sqliteDriver:
41 return `CREATE TABLE IF NOT EXISTS migrations (
42 id INTEGER PRIMARY KEY AUTOINCREMENT,
43 name TEXT NOT NULL,
44 version INTEGER NOT NULL UNIQUE
45 );
46 `
47 case postgresDriver:
48 return `CREATE TABLE IF NOT EXISTS migrations (
49 id SERIAL PRIMARY KEY,
50 name TEXT NOT NULL,
51 version INTEGER NOT NULL UNIQUE
52 );
53 `
54 case "mysql":
55 return `CREATE TABLE IF NOT EXISTS migrations (
56 id INT NOT NULL AUTO_INCREMENT,
57 name TEXT NOT NULL,
58 version INT NOT NULL,
59 UNIQUE (version),
60 PRIMARY KEY (id)
61 );
62 `
63 default:
64 panic("unknown driver")
65 }
66}
67
68// Migrate runs the migrations.
69func Migrate(ctx context.Context, dbx *db.DB) error {
70 logger := log.FromContext(ctx).WithPrefix("migrate")
71 return dbx.TransactionContext(ctx, func(tx *db.Tx) error { //nolint:wrapcheck
72 if !hasTable(tx, "migrations") {
73 if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {
74 return err //nolint:wrapcheck
75 }
76 }
77
78 var migrs Migrations
79 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
80 if !errors.Is(err, sql.ErrNoRows) {
81 return err //nolint:wrapcheck
82 }
83 }
84
85 for _, m := range migrations {
86 if m.Version <= migrs.Version {
87 continue
88 }
89
90 logger.Infof("running migration %d. %s", m.Version, m.Name)
91 if err := m.Migrate(ctx, tx); err != nil {
92 return err
93 }
94
95 if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
96 return err //nolint:wrapcheck
97 }
98 }
99
100 return nil
101 })
102}
103
104// Rollback rolls back a migration.
105func Rollback(ctx context.Context, dbx *db.DB) error {
106 logger := log.FromContext(ctx).WithPrefix("migrate")
107 return dbx.TransactionContext(ctx, func(tx *db.Tx) error { //nolint:wrapcheck
108 var migrs Migrations
109 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
110 if !errors.Is(err, sql.ErrNoRows) {
111 return fmt.Errorf("there are no migrations to rollback: %w", err)
112 }
113 }
114
115 if migrs.Version == 0 || len(migrations) < int(migrs.Version) {
116 return fmt.Errorf("there are no migrations to rollback")
117 }
118
119 m := migrations[migrs.Version-1]
120 logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
121 if err := m.Rollback(ctx, tx); err != nil {
122 return err
123 }
124
125 if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
126 return err //nolint:wrapcheck
127 }
128
129 return nil
130 })
131}
132
133func hasTable(tx *db.Tx, tableName string) bool {
134 var query string
135 switch tx.DriverName() {
136 case sqlite3Driver, sqliteDriver:
137 query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
138 case postgresDriver:
139 fallthrough
140 case "mysql":
141 query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
142 }
143
144 query = tx.Rebind(query)
145 var name string
146 err := tx.Get(&name, query, tableName)
147 return err == nil
148}