Implement migration system, add first migration

Amolith created

Thank you for the help Chris!
https://github.com/whereswaldon

Change summary

cmd/willow.go                     |  20 ++--
db/db.go                          |  40 +++++---
db/migrations.go                  | 133 +++++++++++++++++++++++++++++++++
db/posthooks.go                   |  57 ++++++++++++++
db/sql/1_add_project_ids.down.sql |  26 ++++++
db/sql/1_add_project_ids.up.sql   |  14 +++
project/project.go                |   3 
7 files changed, 267 insertions(+), 26 deletions(-)

Detailed changes

cmd/willow.go 🔗

@@ -63,18 +63,18 @@ func main() {
 		os.Exit(1)
 	}
 
-	fmt.Println("Verifying database schema")
-	err = db.VerifySchema(dbConn)
+	fmt.Println("Checking whether database needs initialising")
+	err = db.InitialiseDatabase(dbConn)
 	if err != nil {
-		fmt.Println("Error verifying database schema:", err)
-		fmt.Println("Attempting to load schema")
-		err = db.LoadSchema(dbConn)
-		if err != nil {
-			fmt.Println("Error loading schema:", err)
-			os.Exit(1)
-		}
+		fmt.Println("Error initialising database:", err)
+		os.Exit(1)
+	}
+	fmt.Println("Checking whether there are pending migrations")
+	err = db.Migrate(dbConn)
+	if err != nil {
+		fmt.Println("Error migrating database schema:", err)
+		os.Exit(1)
 	}
-	fmt.Println("Database schema verified")
 
 	if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
 		createUser(dbConn, *flagAddUser)

db/db.go 🔗

@@ -6,46 +6,54 @@ package db
 
 import (
 	"database/sql"
-	"embed"
+	_ "embed"
 
 	_ "modernc.org/sqlite"
 )
 
-// Embed the schema into the binary
-//
-//go:embed sql
-var embeddedSQL embed.FS
+//go:embed sql/schema.sql
+var schema string
 
 // Open opens a connection to the SQLite database
 func Open(dbPath string) (*sql.DB, error) {
 	return sql.Open("sqlite", dbPath)
 }
 
-func VerifySchema(dbConn *sql.DB) error {
+// VerifySchema checks whether the schema has been initalised and initialises it
+// if not
+func InitialiseDatabase(dbConn *sql.DB) error {
+	var name string
+	err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'").Scan(&name)
+	if err == nil {
+		return nil
+	}
+
 	tables := []string{
 		"users",
 		"sessions",
 		"projects",
+		"releases",
 	}
 
 	for _, table := range tables {
 		name := ""
-		err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name)
+		err := dbConn.QueryRow(
+			"SELECT name FROM sqlite_master WHERE type='table' AND name=@table",
+			sql.Named("table", table),
+		).Scan(&name)
 		if err != nil {
-			return err
+			if err = loadSchema(dbConn); err != nil {
+				return err
+			}
 		}
 	}
 	return nil
 }
 
-// LoadSchema loads the schema into the database
-func LoadSchema(dbConn *sql.DB) error {
-	schema, err := embeddedSQL.ReadFile("sql/schema.sql")
-	if err != nil {
+// loadSchema loads the initial schema into the database
+func loadSchema(dbConn *sql.DB) error {
+	if _, err := dbConn.Exec(schema); err != nil {
 		return err
 	}
-
-	_, err = dbConn.Exec(string(schema))
-
-	return err
+	return nil
 }

db/migrations.go 🔗

@@ -0,0 +1,133 @@
+// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+	"context"
+	"database/sql"
+	_ "embed"
+	"fmt"
+)
+
+type migration struct {
+	upQuery   string
+	downQuery string
+	postHook  func(*sql.Tx) error
+}
+
+var (
+	//go:embed sql/1_add_project_ids.up.sql
+	migration1Up string
+	//go:embed sql/1_add_project_ids.down.sql
+	migration1Down string
+)
+
+var migrations = [...]migration{
+	0: {
+		upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
+		INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
+		downQuery: `DROP TABLE schema_migrations;`,
+	},
+	1: {
+		upQuery:   migration1Up,
+		downQuery: migration1Down,
+		postHook:  generateAndInsertProjectIDs,
+	},
+}
+
+// Migrate runs all pending migrations
+func Migrate(db *sql.DB) error {
+	version := getSchemaVersion(db)
+	for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
+		if err := runMigration(db, nextMigration); err != nil {
+			return fmt.Errorf("migrations failed: %w", err)
+		}
+		if version := getSchemaVersion(db); version != nextMigration {
+			return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
+		}
+	}
+	return nil
+}
+
+// runMigration runs a single migration inside a transaction, updates the schema
+// version and commits the transaction if successful, and rolls back the
+// transaction if unsuccessful.
+func runMigration(db *sql.DB, migrationIdx int) (err error) {
+	current := migrations[migrationIdx]
+	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
+	if err != nil {
+		return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
+	}
+	defer func() {
+		if err == nil {
+			err = tx.Commit()
+		}
+		if err != nil {
+			if rbErr := tx.Rollback(); rbErr != nil {
+				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
+			}
+		}
+	}()
+	if len(current.upQuery) > 0 {
+		if _, err := tx.Exec(current.upQuery); err != nil {
+			return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
+		}
+	}
+	if current.postHook != nil {
+		if err := current.postHook(tx); err != nil {
+			return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
+		}
+	}
+	return updateSchemaVersion(tx, migrationIdx)
+}
+
+// undoMigration rolls the single most recent migration back inside a
+// transaction, updates the schema version and commits the transaction if
+// successful, and rolls back the transaction if unsuccessful.
+//
+//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
+func undoMigration(db *sql.DB, migrationIdx int) (err error) {
+	current := migrations[migrationIdx]
+	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
+	if err != nil {
+		return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
+	}
+	defer func() {
+		if err == nil {
+			err = tx.Commit()
+		}
+		if err != nil {
+			if rbErr := tx.Rollback(); rbErr != nil {
+				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
+			}
+		}
+	}()
+	if len(current.downQuery) > 0 {
+		if _, err := tx.Exec(current.downQuery); err != nil {
+			return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
+		}
+	}
+	return updateSchemaVersion(tx, migrationIdx-1)
+}
+
+// getSchemaVersion returns the schema version from the database
+func getSchemaVersion(db *sql.DB) int {
+	row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`)
+	var version int
+	if err := row.Scan(&version); err != nil {
+		version = -1
+	}
+	return version
+}
+
+// updateSchemaVersion sets the version to the provided int
+func updateSchemaVersion(tx *sql.Tx, version int) error {
+	if version < 0 {
+		// Do not try to use the schema_migrations table in a schema version where it doesn't exist
+		return nil
+	}
+	_, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version))
+	return err
+}

db/posthooks.go 🔗

@@ -0,0 +1,57 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+	"crypto/sha256"
+	"database/sql"
+	"fmt"
+)
+
+// generateAndInsertProjectIDs runs during migration 1, fetches all rows from
+// projects_tmp, loops through the rows generating a repeatable ID for each
+// project, and inserting it into the new table along with the data from the old
+// table.
+func generateAndInsertProjectIDs(tx *sql.Tx) error {
+	// Loop through projects_tmp, generate a project_id for each, and insert
+	// into projects
+	rows, err := tx.Query("SELECT url, name, forge, version, created_at FROM projects_tmp")
+	if err != nil {
+		return fmt.Errorf("failed to list projects in projects_tmp: %w", err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var (
+			url        string
+			name       string
+			forge      string
+			version    string
+			created_at string
+		)
+		if err := rows.Scan(&url, &name, &forge, &version, &created_at); err != nil {
+			return fmt.Errorf("failed to scan row from projects_tmp: %w", err)
+		}
+		id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge+created_at)))
+		_, err = tx.Exec(
+			"INSERT INTO projects (id, url, name, forge, version, created_at) VALUES (@id, @url, @name, @forge, @version, @created_at)",
+			sql.Named("id", id),
+			sql.Named("url", url),
+			sql.Named("name", name),
+			sql.Named("forge", forge),
+			sql.Named("version", version),
+			sql.Named("created_at", created_at),
+		)
+		if err != nil {
+			return fmt.Errorf("failed to insert project into projects: %w", err)
+		}
+	}
+
+	if _, err := tx.Exec("DROP TABLE projects_tmp"); err != nil {
+		return fmt.Errorf("failed to drop projects_tmp: %w", err)
+	}
+
+	return nil
+}

db/sql/1_add_project_ids.down.sql 🔗

@@ -0,0 +1,26 @@
+-- SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+--
+-- SPDX-License-Identifier: CC0-1.0
+
+--ALTER TABLE projects RENAME TO projects_tmp; -- noqa
+
+ALTER TABLE projects RENAME TO projects_tmp;
+
+CREATE TABLE IF NOT EXISTS projects (
+    url TEXT NOT NULL PRIMARY KEY,
+    name TEXT NOT NULL,
+    forge TEXT NOT NULL,
+    version TEXT NOT NULL,
+    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+INSERT INTO projects (url, name, forge, version, created_at)
+SELECT
+    url,
+    name,
+    forge,
+    version,
+    created_at
+FROM projects_tmp;
+
+DROP TABLE projects_tmp;

db/sql/1_add_project_ids.up.sql 🔗

@@ -0,0 +1,14 @@
+-- SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+--
+-- SPDX-License-Identifier: CC0-1.0
+
+ALTER TABLE projects RENAME TO projects_tmp;
+
+CREATE TABLE IF NOT EXISTS projects (
+    id TEXT NOT NULL PRIMARY KEY,
+    url TEXT NOT NULL,
+    name TEXT NOT NULL,
+    forge TEXT NOT NULL,
+    version TEXT NOT NULL,
+    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);

project/project.go 🔗

@@ -29,6 +29,7 @@ type Project struct {
 }
 
 type Release struct {
+	ID      string
 	URL     string
 	Tag     string
 	Content string
@@ -70,6 +71,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) {
 		}
 		for _, release := range rssReleases {
 			p.Releases = append(p.Releases, Release{
+				ID:      genReleaseID(p.URL, release.URL, release.Tag),
 				Tag:     release.Tag,
 				Content: release.Content,
 				URL:     release.URL,
@@ -88,6 +90,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) {
 		}
 		for _, release := range gitReleases {
 			p.Releases = append(p.Releases, Release{
+				ID:      genReleaseID(p.URL, release.URL, release.Tag),
 				Tag:     release.Tag,
 				Content: release.Content,
 				URL:     release.URL,