migrations.go

  1// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5package db
  6
  7import (
  8	"context"
  9	"database/sql"
 10	_ "embed"
 11	"fmt"
 12)
 13
// migration describes a single schema migration step.
//
// upQuery is the SQL executed to apply the step and downQuery is the SQL
// executed to undo it; either may be empty, in which case it is skipped.
// postHook, when non-nil, is invoked with the open transaction after upQuery
// has run, and a non-nil error from it fails the migration.
type migration struct {
	upQuery   string
	downQuery string
	postHook  func(*sql.Tx) error
}
 19
// SQL text for migration 1, embedded into the binary at build time from the
// sql/ directory next to this file. migration1Up applies the change and
// migration1Down reverts it.
var (
	//go:embed sql/1_add_project_ids.up.sql
	migration1Up string
	//go:embed sql/1_add_project_ids.down.sql
	migration1Down string
)
 26
// migrations is the ordered list of all known schema migrations. The array
// index of an entry is the schema version that runMigration records in
// schema_migrations once that entry has been applied, so entries must never be
// reordered or removed.
var migrations = [...]migration{
	// Migration 0 bootstraps version tracking: it creates the
	// schema_migrations table and seeds it with version 0.
	0: {
		upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
		INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
		downQuery: `DROP TABLE schema_migrations;`,
	},
	// Migration 1 runs the embedded 1_add_project_ids SQL and then calls
	// generateAndInsertProjectIDs inside the same transaction.
	1: {
		upQuery:   migration1Up,
		downQuery: migration1Down,
		postHook:  generateAndInsertProjectIDs,
	},
}
 39
 40// Migrate runs all pending migrations
 41func Migrate(db *sql.DB) error {
 42	version := getSchemaVersion(db)
 43	for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
 44		if err := runMigration(db, nextMigration); err != nil {
 45			return fmt.Errorf("migrations failed: %w", err)
 46		}
 47		if version := getSchemaVersion(db); version != nextMigration {
 48			return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
 49		}
 50	}
 51	return nil
 52}
 53
 54// runMigration runs a single migration inside a transaction, updates the schema
 55// version and commits the transaction if successful, and rolls back the
 56// transaction if unsuccessful.
 57func runMigration(db *sql.DB, migrationIdx int) (err error) {
 58	current := migrations[migrationIdx]
 59	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
 60	if err != nil {
 61		return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
 62	}
 63	defer func() {
 64		if err == nil {
 65			err = tx.Commit()
 66		}
 67		if err != nil {
 68			if rbErr := tx.Rollback(); rbErr != nil {
 69				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
 70			}
 71		}
 72	}()
 73	if len(current.upQuery) > 0 {
 74		if _, err := tx.Exec(current.upQuery); err != nil {
 75			return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
 76		}
 77	}
 78	if current.postHook != nil {
 79		if err := current.postHook(tx); err != nil {
 80			return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
 81		}
 82	}
 83	return updateSchemaVersion(tx, migrationIdx)
 84}
 85
 86// undoMigration rolls the single most recent migration back inside a
 87// transaction, updates the schema version and commits the transaction if
 88// successful, and rolls back the transaction if unsuccessful.
 89//
 90//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
 91func undoMigration(db *sql.DB, migrationIdx int) (err error) {
 92	current := migrations[migrationIdx]
 93	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
 94	if err != nil {
 95		return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
 96	}
 97	defer func() {
 98		if err == nil {
 99			err = tx.Commit()
100		}
101		if err != nil {
102			if rbErr := tx.Rollback(); rbErr != nil {
103				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
104			}
105		}
106	}()
107	if len(current.downQuery) > 0 {
108		if _, err := tx.Exec(current.downQuery); err != nil {
109			return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
110		}
111	}
112	return updateSchemaVersion(tx, migrationIdx-1)
113}
114
// getSchemaVersion reads the schema version stored in the schema_migrations
// table. It returns -1 whenever the version cannot be read, whatever the
// reason: the query failing (for example because the table does not exist
// yet), the table holding no row, or the value not scanning into an int.
func getSchemaVersion(db *sql.DB) int {
	const query = `SELECT version FROM schema_migrations LIMIT 1;`

	var stored int
	err := db.QueryRowContext(context.Background(), query).Scan(&stored)
	if err != nil {
		return -1
	}
	return stored
}
124
// updateSchemaVersion stores version in the schema_migrations table using the
// given transaction. A negative version is a no-op that returns nil: it stands
// for the state in which the schema_migrations table does not exist, so there
// is nowhere to record it.
func updateSchemaVersion(tx *sql.Tx, version int) error {
	const stmt = `UPDATE schema_migrations SET version = @version;`

	if version < 0 {
		return nil
	}
	if _, err := tx.Exec(stmt, sql.Named("version", version)); err != nil {
		return err
	}
	return nil
}
133}