migrations.go

  1// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5package db
  6
import (
	"context"
	"database/sql"
	_ "embed"
	"errors"
	"fmt"
)
 13
 14type migration struct {
 15	upQuery   string
 16	downQuery string
 17	postHook  func(*sql.Tx) error
 18}
 19
// SQL scripts for the versioned migrations, embedded into the binary at build
// time from the sql/ directory and referenced by the migrations table.
//
// Each //go:embed directive must stay directly above the variable it fills,
// so explanatory comments go above the directive, never between the
// directive and the variable.
var (
	// Migration 1 (sql/1_add_project_ids.*): apply / revert scripts.
	//go:embed sql/1_add_project_ids.up.sql
	migration1Up string
	//go:embed sql/1_add_project_ids.down.sql
	migration1Down string
	// Migration 2 (sql/2_swap_project_url_for_id.*): apply / revert scripts.
	//go:embed sql/2_swap_project_url_for_id.up.sql
	migration2Up string
	//go:embed sql/2_swap_project_url_for_id.down.sql
	migration2Down string
)
 30
 31var migrations = [...]migration{
 32	0: {
 33		upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
 34		INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
 35		downQuery: `DROP TABLE schema_migrations;`,
 36		postHook:  nil,
 37	},
 38	1: {
 39		upQuery:   migration1Up,
 40		downQuery: migration1Down,
 41		postHook:  generateAndInsertProjectIDs,
 42	},
 43	2: {
 44		upQuery:   migration2Up,
 45		downQuery: migration2Down,
 46		postHook:  nil,
 47	},
 48	3: {
 49		upQuery:   "",
 50		downQuery: "",
 51		postHook:  correctProjectIDs,
 52	},
 53}
 54
 55// Migrate runs all pending migrations.
 56func Migrate(db *sql.DB) error {
 57	version := getSchemaVersion(db)
 58	for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
 59		if err := runMigration(db, nextMigration); err != nil {
 60			return fmt.Errorf("migrations failed: %w", err)
 61		}
 62
 63		if version := getSchemaVersion(db); version != nextMigration {
 64			return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
 65		}
 66	}
 67
 68	return nil
 69}
 70
 71// runMigration runs a single migration inside a transaction, updates the schema
 72// version and commits the transaction if successful, and rolls back the
 73// transaction if unsuccessful.
 74func runMigration(db *sql.DB, migrationIdx int) (err error) {
 75	current := migrations[migrationIdx]
 76
 77	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
 78		Isolation: 0,
 79		ReadOnly:  false,
 80	})
 81	if err != nil {
 82		return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
 83	}
 84
 85	defer func() {
 86		if err == nil {
 87			err = tx.Commit()
 88		}
 89
 90		if err != nil {
 91			if rbErr := tx.Rollback(); rbErr != nil {
 92				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
 93			}
 94		}
 95	}()
 96
 97	if len(current.upQuery) > 0 {
 98		if _, err := tx.Exec(current.upQuery); err != nil {
 99			return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
100		}
101	}
102
103	if current.postHook != nil {
104		if err := current.postHook(tx); err != nil {
105			return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
106		}
107	}
108
109	return updateSchemaVersion(tx, migrationIdx)
110}
111
112// undoMigration rolls the single most recent migration back inside a
113// transaction, updates the schema version and commits the transaction if
114// successful, and rolls back the transaction if unsuccessful.
115//
116//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
117func undoMigration(db *sql.DB, migrationIdx int) (err error) {
118	current := migrations[migrationIdx]
119
120	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
121		Isolation: 0,
122		ReadOnly:  false,
123	})
124	if err != nil {
125		return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
126	}
127
128	defer func() {
129		if err == nil {
130			err = tx.Commit()
131		}
132
133		if err != nil {
134			if rbErr := tx.Rollback(); rbErr != nil {
135				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
136			}
137		}
138	}()
139
140	if len(current.downQuery) > 0 {
141		if _, err := tx.Exec(current.downQuery); err != nil {
142			return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
143		}
144	}
145
146	return updateSchemaVersion(tx, migrationIdx-1)
147}
148
// getSchemaVersion reads the version stored in the schema_migrations table.
// Any failure to obtain it is reported as -1 rather than as an error. Such a
// failure includes a missing table, a table with no rows, or a failed query or
// connection.
func getSchemaVersion(db *sql.DB) int {
	const query = `SELECT version FROM schema_migrations LIMIT 1;`

	var stored int
	if scanErr := db.QueryRowContext(context.Background(), query).Scan(&stored); scanErr != nil {
		return -1
	}
	return stored
}
160
// updateSchemaVersion writes version into the schema_migrations table through
// tx. A negative version is a no-op that never touches tx. undoMigration
// passes -1 after reverting migration 0, at which point schema_migrations no
// longer exists.
func updateSchemaVersion(tx *sql.Tx, version int) error {
	if version >= 0 {
		const stmt = `UPDATE schema_migrations SET version = @version;`
		if _, execErr := tx.Exec(stmt, sql.Named("version", version)); execErr != nil {
			return fmt.Errorf("failed to execute SQL: %w", execErr)
		}
	}

	return nil
}