1package migrate
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 "github.com/charmbracelet/log"
10 "github.com/charmbracelet/soft-serve/pkg/db"
11)
12
// MigrateFunc is the signature shared by every migration step. It receives
// the caller's context and a database handler, and reports failure by
// returning a non-nil error. Within this file the handler passed in is either
// the *db.DB itself or a *db.Tx, depending on the step.
type MigrateFunc func(ctx context.Context, h db.Handler) error // nolint:revive
15
// Migration describes a single migration: its version, a human-readable
// name, and the functions that apply or undo it.
type Migration struct {
	// Version orders the migration. Migrate skips migrations whose Version is
	// not greater than the latest recorded version, and Rollback looks up the
	// migration to undo at index Version-1 of the migrations list.
	Version int64
	// Name is logged and stored next to Version in the migrations table.
	Name string
	// PreMigrate, when non-nil, is called with the *db.DB before the Migrate
	// transaction is started.
	PreMigrate MigrateFunc
	// Migrate is called with a *db.Tx inside a transaction opened by the
	// package-level Migrate function.
	Migrate MigrateFunc
	// PostMigrate, when non-nil, is called with the *db.DB after the Migrate
	// transaction has returned without error.
	PostMigrate MigrateFunc
	// Rollback undoes the migration. It is called with a *db.Tx by the
	// package-level Rollback function.
	Rollback MigrateFunc
}
26
// Migrations is the database model for one row of the migrations table. The
// db struct tags map each field to the column of the same name.
type Migrations struct {
	// ID is the row's auto-generated primary key.
	ID int64 `db:"id"`
	// Name is the name of the migration stored in the row.
	Name string `db:"name"`
	// Version is the version of the migration stored in the row; the table
	// schema declares this column unique.
	Version int64 `db:"version"`
}
33
34func (Migrations) schema(driverName string) string {
35 switch driverName {
36 case "sqlite3", "sqlite":
37 return `CREATE TABLE IF NOT EXISTS migrations (
38 id INTEGER PRIMARY KEY AUTOINCREMENT,
39 name TEXT NOT NULL,
40 version INTEGER NOT NULL UNIQUE
41 );
42 `
43 case "postgres":
44 return `CREATE TABLE IF NOT EXISTS migrations (
45 id SERIAL PRIMARY KEY,
46 name TEXT NOT NULL,
47 version INTEGER NOT NULL UNIQUE
48 );
49 `
50 case "mysql":
51 return `CREATE TABLE IF NOT EXISTS migrations (
52 id INT NOT NULL AUTO_INCREMENT,
53 name TEXT NOT NULL,
54 version INT NOT NULL,
55 UNIQUE (version),
56 PRIMARY KEY (id)
57 );
58 `
59 default:
60 panic("unknown driver")
61 }
62}
63
64// Migrate runs the migrations.
65func Migrate(ctx context.Context, dbx *db.DB) error {
66 logger := log.FromContext(ctx).WithPrefix("migrate")
67
68 if !hasTable(dbx, "migrations") {
69 if _, err := dbx.Exec(Migrations{}.schema(dbx.DriverName())); err != nil {
70 return err
71 }
72 }
73
74 var migrs Migrations
75 if err := dbx.Get(&migrs, dbx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
76 if !errors.Is(err, sql.ErrNoRows) {
77 return err
78 }
79 }
80
81 for _, m := range migrations {
82 if m.Version <= migrs.Version {
83 continue
84 }
85
86 logger.Infof("running migration %d. %s", m.Version, m.Name)
87 if m.PreMigrate != nil {
88 if err := m.PreMigrate(ctx, dbx); err != nil {
89 return err
90 }
91 }
92
93 if err := dbx.TransactionContext(ctx, func(tx *db.Tx) error {
94 return m.Migrate(ctx, tx)
95 }); err != nil {
96 return err
97 }
98
99 if m.PostMigrate != nil {
100 if err := m.PostMigrate(ctx, dbx); err != nil {
101 return err
102 }
103 }
104
105 if _, err := dbx.Exec(dbx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
106 return err
107 }
108 }
109
110 return nil
111}
112
113// Rollback rolls back a migration.
114func Rollback(ctx context.Context, dbx *db.DB) error {
115 logger := log.FromContext(ctx).WithPrefix("migrate")
116 return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
117 var migrs Migrations
118 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
119 if !errors.Is(err, sql.ErrNoRows) {
120 return fmt.Errorf("there are no migrations to rollback: %w", err)
121 }
122 }
123
124 if migrs.Version == 0 || len(migrations) < int(migrs.Version) {
125 return fmt.Errorf("there are no migrations to rollback")
126 }
127
128 m := migrations[migrs.Version-1]
129 logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
130 if err := m.Rollback(ctx, tx); err != nil {
131 return err
132 }
133
134 if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
135 return err
136 }
137
138 return nil
139 })
140}
141
142func hasTable(tx db.Handler, tableName string) bool {
143 var query string
144 switch tx.DriverName() {
145 case "sqlite3", "sqlite":
146 query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
147 case "postgres":
148 fallthrough
149 case "mysql":
150 query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
151 }
152
153 query = tx.Rebind(query)
154 var name string
155 err := tx.Get(&name, query, tableName)
156 return err == nil
157}