1// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5package db
6
7import (
8 "context"
9 "database/sql"
10 _ "embed"
11 "fmt"
12)
13
// migration describes one step of the schema history. Applying it runs
// upQuery and then postHook; undoing it runs downQuery only.
type migration struct {
	// upQuery is SQL executed when applying the migration; empty means no SQL.
	upQuery string
	// downQuery is SQL executed when undoing the migration; empty means no SQL.
	downQuery string
	// postHook, when non-nil, runs after upQuery inside the same transaction.
	postHook func(*sql.Tx) error
}
19
20var (
21 //go:embed sql/1_add_project_ids.up.sql
22 migration1Up string
23 //go:embed sql/1_add_project_ids.down.sql
24 migration1Down string
25 //go:embed sql/2_swap_project_url_for_id.up.sql
26 migration2Up string
27 //go:embed sql/2_swap_project_url_for_id.down.sql
28 migration2Down string
29)
30
31var migrations = [...]migration{
32 0: {
33 upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
34 INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
35 downQuery: `DROP TABLE schema_migrations;`,
36 postHook: nil,
37 },
38 1: {
39 upQuery: migration1Up,
40 downQuery: migration1Down,
41 postHook: generateAndInsertProjectIDs,
42 },
43 2: {
44 upQuery: migration2Up,
45 downQuery: migration2Down,
46 postHook: nil,
47 },
48 3: {
49 upQuery: "",
50 downQuery: "",
51 postHook: correctProjectIDs,
52 },
53}
54
55// Migrate runs all pending migrations.
56func Migrate(db *sql.DB) error {
57 version := getSchemaVersion(db)
58 for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
59 if err := runMigration(db, nextMigration); err != nil {
60 return fmt.Errorf("migrations failed: %w", err)
61 }
62
63 if version := getSchemaVersion(db); version != nextMigration {
64 return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
65 }
66 }
67
68 return nil
69}
70
71// runMigration runs a single migration inside a transaction, updates the schema
72// version and commits the transaction if successful, and rolls back the
73// transaction if unsuccessful.
74func runMigration(db *sql.DB, migrationIdx int) (err error) {
75 current := migrations[migrationIdx]
76
77 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
78 Isolation: 0,
79 ReadOnly: false,
80 })
81 if err != nil {
82 return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
83 }
84
85 defer func() {
86 if err == nil {
87 err = tx.Commit()
88 }
89
90 if err != nil {
91 if rbErr := tx.Rollback(); rbErr != nil {
92 err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
93 }
94 }
95 }()
96
97 if len(current.upQuery) > 0 {
98 if _, err := tx.Exec(current.upQuery); err != nil {
99 return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
100 }
101 }
102
103 if current.postHook != nil {
104 if err := current.postHook(tx); err != nil {
105 return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
106 }
107 }
108
109 return updateSchemaVersion(tx, migrationIdx)
110}
111
112// undoMigration rolls the single most recent migration back inside a
113// transaction, updates the schema version and commits the transaction if
114// successful, and rolls back the transaction if unsuccessful.
115//
116//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
117func undoMigration(db *sql.DB, migrationIdx int) (err error) {
118 current := migrations[migrationIdx]
119
120 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
121 Isolation: 0,
122 ReadOnly: false,
123 })
124 if err != nil {
125 return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
126 }
127
128 defer func() {
129 if err == nil {
130 err = tx.Commit()
131 }
132
133 if err != nil {
134 if rbErr := tx.Rollback(); rbErr != nil {
135 err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
136 }
137 }
138 }()
139
140 if len(current.downQuery) > 0 {
141 if _, err := tx.Exec(current.downQuery); err != nil {
142 return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
143 }
144 }
145
146 return updateSchemaVersion(tx, migrationIdx-1)
147}
148
// getSchemaVersion returns the schema version from the database.
//
// It reads the version column of the first schema_migrations row. Any failure
// to obtain a value is reported as -1: a query error (for example, the table
// not existing yet), no rows, or a scan error. The error itself is discarded,
// so callers cannot tell these cases apart.
func getSchemaVersion(db *sql.DB) int {
	var stored int
	scanErr := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`).Scan(&stored)
	if scanErr != nil {
		return -1
	}
	return stored
}
160
// updateSchemaVersion sets the version to the provided int.
//
// The UPDATE runs inside tx and applies to every row of schema_migrations.
// For a negative version, tx is not touched and nil is returned; this lets
// callers that undo migration 0 (which drops the table) pass -1 safely.
func updateSchemaVersion(tx *sql.Tx, version int) error {
	if version >= 0 {
		_, execErr := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version))
		return execErr
	}

	// Do not try to use the schema_migrations table in a schema version where it doesn't exist
	return nil
}