Detailed changes
@@ -63,18 +63,18 @@ func main() {
os.Exit(1)
}
- fmt.Println("Verifying database schema")
- err = db.VerifySchema(dbConn)
+ fmt.Println("Checking whether database needs initialising")
+ err = db.InitialiseDatabase(dbConn)
if err != nil {
- fmt.Println("Error verifying database schema:", err)
- fmt.Println("Attempting to load schema")
- err = db.LoadSchema(dbConn)
- if err != nil {
- fmt.Println("Error loading schema:", err)
- os.Exit(1)
- }
+ fmt.Println("Error initialising database:", err)
+ os.Exit(1)
+ }
+ fmt.Println("Checking whether there are pending migrations")
+ err = db.Migrate(dbConn)
+ if err != nil {
+ fmt.Println("Error migrating database schema:", err)
+ os.Exit(1)
}
- fmt.Println("Database schema verified")
if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
createUser(dbConn, *flagAddUser)
@@ -6,46 +6,54 @@ package db
import (
"database/sql"
- "embed"
+ _ "embed"
_ "modernc.org/sqlite"
)
-// Embed the schema into the binary
-//
-//go:embed sql
-var embeddedSQL embed.FS
+//go:embed sql/schema.sql
+var schema string
// Open opens a connection to the SQLite database
func Open(dbPath string) (*sql.DB, error) {
return sql.Open("sqlite", dbPath)
}
-func VerifySchema(dbConn *sql.DB) error {
+// VerifySchema checks whether the schema has been initalised and initialises it
+// if not
+func InitialiseDatabase(dbConn *sql.DB) error {
+ var name string
+ err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'").Scan(&name)
+ if err == nil {
+ return nil
+ }
+
tables := []string{
"users",
"sessions",
"projects",
+ "releases",
}
for _, table := range tables {
name := ""
- err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name)
+ err := dbConn.QueryRow(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=@table",
+ sql.Named("table", table),
+ ).Scan(&name)
if err != nil {
- return err
+ if err = loadSchema(dbConn); err != nil {
+ return err
+ }
}
}
return nil
}
-// LoadSchema loads the schema into the database
-func LoadSchema(dbConn *sql.DB) error {
- schema, err := embeddedSQL.ReadFile("sql/schema.sql")
- if err != nil {
+// loadSchema loads the initial schema into the database
+func loadSchema(dbConn *sql.DB) error {
+ if _, err := dbConn.Exec(schema); err != nil {
return err
}
-
- _, err = dbConn.Exec(string(schema))
-
- return err
+ return nil
}
@@ -0,0 +1,133 @@
+// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "fmt"
+)
+
+type migration struct {
+ upQuery string
+ downQuery string
+ postHook func(*sql.Tx) error
+}
+
+var (
+ //go:embed sql/1_add_project_ids.up.sql
+ migration1Up string
+ //go:embed sql/1_add_project_ids.down.sql
+ migration1Down string
+)
+
+var migrations = [...]migration{
+ 0: {
+ upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
+ INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
+ downQuery: `DROP TABLE schema_migrations;`,
+ },
+ 1: {
+ upQuery: migration1Up,
+ downQuery: migration1Down,
+ postHook: generateAndInsertProjectIDs,
+ },
+}
+
+// Migrate runs all pending migrations
+func Migrate(db *sql.DB) error {
+ version := getSchemaVersion(db)
+ for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
+ if err := runMigration(db, nextMigration); err != nil {
+ return fmt.Errorf("migrations failed: %w", err)
+ }
+ if version := getSchemaVersion(db); version != nextMigration {
+ return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
+ }
+ }
+ return nil
+}
+
+// runMigration runs a single migration inside a transaction, updates the schema
+// version and commits the transaction if successful, and rolls back the
+// transaction if unsuccessful.
+func runMigration(db *sql.DB, migrationIdx int) (err error) {
+ current := migrations[migrationIdx]
+ tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
+ if err != nil {
+ return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
+ }
+ defer func() {
+ if err == nil {
+ err = tx.Commit()
+ }
+ if err != nil {
+ if rbErr := tx.Rollback(); rbErr != nil {
+ err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
+ }
+ }
+ }()
+ if len(current.upQuery) > 0 {
+ if _, err := tx.Exec(current.upQuery); err != nil {
+ return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
+ }
+ }
+ if current.postHook != nil {
+ if err := current.postHook(tx); err != nil {
+ return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
+ }
+ }
+ return updateSchemaVersion(tx, migrationIdx)
+}
+
+// undoMigration rolls the single most recent migration back inside a
+// transaction, updates the schema version and commits the transaction if
+// successful, and rolls back the transaction if unsuccessful.
+//
+//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
+func undoMigration(db *sql.DB, migrationIdx int) (err error) {
+ current := migrations[migrationIdx]
+ tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
+ if err != nil {
+ return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
+ }
+ defer func() {
+ if err == nil {
+ err = tx.Commit()
+ }
+ if err != nil {
+ if rbErr := tx.Rollback(); rbErr != nil {
+ err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
+ }
+ }
+ }()
+ if len(current.downQuery) > 0 {
+ if _, err := tx.Exec(current.downQuery); err != nil {
+ return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
+ }
+ }
+ return updateSchemaVersion(tx, migrationIdx-1)
+}
+
+// getSchemaVersion returns the schema version from the database
+func getSchemaVersion(db *sql.DB) int {
+ row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`)
+ var version int
+ if err := row.Scan(&version); err != nil {
+ version = -1
+ }
+ return version
+}
+
+// updateSchemaVersion sets the version to the provided int
+func updateSchemaVersion(tx *sql.Tx, version int) error {
+ if version < 0 {
+ // Do not try to use the schema_migrations table in a schema version where it doesn't exist
+ return nil
+ }
+ _, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version))
+ return err
+}
@@ -0,0 +1,57 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "crypto/sha256"
+ "database/sql"
+ "fmt"
+)
+
+// generateAndInsertProjectIDs runs during migration 1, fetches all rows from
+// projects_tmp, loops through the rows generating a repeatable ID for each
+// project, and inserting it into the new table along with the data from the old
+// table.
+func generateAndInsertProjectIDs(tx *sql.Tx) error {
+ // Loop through projects_tmp, generate a project_id for each, and insert
+ // into projects
+ rows, err := tx.Query("SELECT url, name, forge, version, created_at FROM projects_tmp")
+ if err != nil {
+ return fmt.Errorf("failed to list projects in projects_tmp: %w", err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var (
+ url string
+ name string
+ forge string
+ version string
+ created_at string
+ )
+ if err := rows.Scan(&url, &name, &forge, &version, &created_at); err != nil {
+ return fmt.Errorf("failed to scan row from projects_tmp: %w", err)
+ }
+ id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge+created_at)))
+ _, err = tx.Exec(
+ "INSERT INTO projects (id, url, name, forge, version, created_at) VALUES (@id, @url, @name, @forge, @version, @created_at)",
+ sql.Named("id", id),
+ sql.Named("url", url),
+ sql.Named("name", name),
+ sql.Named("forge", forge),
+ sql.Named("version", version),
+ sql.Named("created_at", created_at),
+ )
+ if err != nil {
+ return fmt.Errorf("failed to insert project into projects: %w", err)
+ }
+ }
+
+ if _, err := tx.Exec("DROP TABLE projects_tmp"); err != nil {
+ return fmt.Errorf("failed to drop projects_tmp: %w", err)
+ }
+
+ return nil
+}
@@ -0,0 +1,26 @@
+-- SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+--
+-- SPDX-License-Identifier: CC0-1.0
+
+--ALTER TABLE projects RENAME TO projects_tmp; -- noqa
+
+ALTER TABLE projects RENAME TO projects_tmp;
+
+CREATE TABLE IF NOT EXISTS projects (
+ url TEXT NOT NULL PRIMARY KEY,
+ name TEXT NOT NULL,
+ forge TEXT NOT NULL,
+ version TEXT NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+INSERT INTO projects (url, name, forge, version, created_at)
+SELECT
+ url,
+ name,
+ forge,
+ version,
+ created_at
+FROM projects_tmp;
+
+DROP TABLE projects_tmp;
@@ -0,0 +1,14 @@
+-- SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+--
+-- SPDX-License-Identifier: CC0-1.0
+
+ALTER TABLE projects RENAME TO projects_tmp;
+
+CREATE TABLE IF NOT EXISTS projects (
+ id TEXT NOT NULL PRIMARY KEY,
+ url TEXT NOT NULL,
+ name TEXT NOT NULL,
+ forge TEXT NOT NULL,
+ version TEXT NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
@@ -29,6 +29,7 @@ type Project struct {
}
type Release struct {
+ ID string
URL string
Tag string
Content string
@@ -70,6 +71,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) {
}
for _, release := range rssReleases {
p.Releases = append(p.Releases, Release{
+ ID: genReleaseID(p.URL, release.URL, release.Tag),
Tag: release.Tag,
Content: release.Content,
URL: release.URL,
@@ -88,6 +90,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) {
}
for _, release := range gitReleases {
p.Releases = append(p.Releases, Release{
+ ID: genReleaseID(p.URL, release.URL, release.Tag),
Tag: release.Tag,
Content: release.Content,
URL: release.URL,