1// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5package db
6
7import (
8 "context"
9 "database/sql"
10 _ "embed"
11 "fmt"
12)
13
14type migration struct {
15 upQuery string
16 downQuery string
17 postHook func(*sql.Tx) error
18}
19
// SQL for migration 1, compiled into the binary with //go:embed (hence the
// blank "embed" import). The paths are relative to this package's directory.
var (
	// migration1Up is the forward (up) SQL for migration 1.
	//go:embed sql/1_add_project_ids.up.sql
	migration1Up string
	// migration1Down is the reverting (down) SQL for migration 1.
	//go:embed sql/1_add_project_ids.down.sql
	migration1Down string
)
26
// migrations lists every schema migration in application order. The array
// index doubles as the schema version a migration produces: Migrate applies
// the entries after the version currently stored in the database, and
// runMigration records an entry's index as the new version once it succeeds.
// New migrations must therefore be appended; never insert or reorder entries.
var migrations = [...]migration{
	// Migration 0 bootstraps version tracking: it creates schema_migrations
	// and seeds it with a single row at version 0. Undoing it drops the table;
	// updateSchemaVersion is then skipped because the target version is -1.
	// NOTE(review): nothing in this file reads or updates the dirty column
	// after this initial insert.
	0: {
		upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
		INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
		downQuery: `DROP TABLE schema_migrations;`,
	},
	// Migration 1 (per its SQL file names, adds project IDs) runs the embedded
	// SQL and then generateAndInsertProjectIDs in the same transaction.
	// generateAndInsertProjectIDs is defined elsewhere; presumably it fills in
	// IDs for existing rows (confirm).
	1: {
		upQuery:   migration1Up,
		downQuery: migration1Down,
		postHook:  generateAndInsertProjectIDs,
	},
}
39
40// Migrate runs all pending migrations
41func Migrate(db *sql.DB) error {
42 version := getSchemaVersion(db)
43 for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
44 if err := runMigration(db, nextMigration); err != nil {
45 return fmt.Errorf("migrations failed: %w", err)
46 }
47 if version := getSchemaVersion(db); version != nextMigration {
48 return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
49 }
50 }
51 return nil
52}
53
54// runMigration runs a single migration inside a transaction, updates the schema
55// version and commits the transaction if successful, and rolls back the
56// transaction if unsuccessful.
57func runMigration(db *sql.DB, migrationIdx int) (err error) {
58 current := migrations[migrationIdx]
59 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
60 if err != nil {
61 return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
62 }
63 defer func() {
64 if err == nil {
65 err = tx.Commit()
66 }
67 if err != nil {
68 if rbErr := tx.Rollback(); rbErr != nil {
69 err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
70 }
71 }
72 }()
73 if len(current.upQuery) > 0 {
74 if _, err := tx.Exec(current.upQuery); err != nil {
75 return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
76 }
77 }
78 if current.postHook != nil {
79 if err := current.postHook(tx); err != nil {
80 return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
81 }
82 }
83 return updateSchemaVersion(tx, migrationIdx)
84}
85
86// undoMigration rolls the single most recent migration back inside a
87// transaction, updates the schema version and commits the transaction if
88// successful, and rolls back the transaction if unsuccessful.
89//
90//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
91func undoMigration(db *sql.DB, migrationIdx int) (err error) {
92 current := migrations[migrationIdx]
93 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
94 if err != nil {
95 return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
96 }
97 defer func() {
98 if err == nil {
99 err = tx.Commit()
100 }
101 if err != nil {
102 if rbErr := tx.Rollback(); rbErr != nil {
103 err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
104 }
105 }
106 }()
107 if len(current.downQuery) > 0 {
108 if _, err := tx.Exec(current.downQuery); err != nil {
109 return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
110 }
111 }
112 return updateSchemaVersion(tx, migrationIdx-1)
113}
114
// getSchemaVersion reports the schema version stored in schema_migrations.
// It returns -1 when the version cannot be read for any reason: the table does
// not exist yet, it has no rows, or the query fails. Migrate treats -1 as
// "no migrations applied".
//
// NOTE(review): a genuine database error is indistinguishable from a missing
// table here; both yield -1.
func getSchemaVersion(db *sql.DB) int {
	const query = `SELECT version FROM schema_migrations LIMIT 1;`
	var stored int
	if err := db.QueryRowContext(context.Background(), query).Scan(&stored); err != nil {
		return -1
	}
	return stored
}
124
// updateSchemaVersion records version as the current schema version by
// setting the version column of every row in schema_migrations within tx.
// A negative version is a no-op that returns nil without touching tx.
func updateSchemaVersion(tx *sql.Tx, version int) error {
	if version >= 0 {
		_, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version))
		return err
	}
	// Negative versions describe a schema in which the schema_migrations table
	// does not exist, so there is nothing to update.
	return nil
}