db.go

  1package db
  2
import (
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"sort"
	"strconv"
	"strings"

	_ "modernc.org/sqlite"
)
 15
 16//go:generate go tool github.com/sqlc-dev/sqlc/cmd/sqlc generate
 17
// migrationFS holds the *.sql files from the migrations directory, embedded
// into the binary at build time. RunMigrations lists it and executeMigration
// reads individual scripts from it.
//
//go:embed migrations/*.sql
var migrationFS embed.FS
 20
 21// Open opens an sqlite database and prepares pragmas suitable for a small web app.
 22func Open(path string) (*sql.DB, error) {
 23	db, err := sql.Open("sqlite", path)
 24	if err != nil {
 25		return nil, err
 26	}
 27	// Light pragmas similar
 28	if _, err := db.Exec("PRAGMA foreign_keys=ON;"); err != nil {
 29		_ = db.Close()
 30		return nil, fmt.Errorf("enable foreign keys: %w", err)
 31	}
 32	if _, err := db.Exec("PRAGMA journal_mode=wal;"); err != nil {
 33		_ = db.Close()
 34		return nil, fmt.Errorf("set WAL: %w", err)
 35	}
 36	if _, err := db.Exec("PRAGMA busy_timeout=1000;"); err != nil {
 37		_ = db.Close()
 38		return nil, fmt.Errorf("set busy_timeout: %w", err)
 39	}
 40	return db, nil
 41}
 42
 43// RunMigrations executes database migrations in numeric order (NNN-*.sql),
 44// similar in spirit to exed's exedb.RunMigrations.
 45func RunMigrations(db *sql.DB) error {
 46	entries, err := migrationFS.ReadDir("migrations")
 47	if err != nil {
 48		return fmt.Errorf("read migrations dir: %w", err)
 49	}
 50	var migrations []string
 51	pat := regexp.MustCompile(`^(\d{3})-.*\.sql$`)
 52	for _, e := range entries {
 53		if e.IsDir() {
 54			continue
 55		}
 56		name := e.Name()
 57		if pat.MatchString(name) {
 58			migrations = append(migrations, name)
 59		}
 60	}
 61	sort.Strings(migrations)
 62
 63	executed := make(map[int]bool)
 64	var tableName string
 65	err = db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&tableName)
 66	switch {
 67	case err == nil:
 68		rows, err := db.Query("SELECT migration_number FROM migrations")
 69		if err != nil {
 70			return fmt.Errorf("query executed migrations: %w", err)
 71		}
 72		defer rows.Close()
 73		for rows.Next() {
 74			var n int
 75			if err := rows.Scan(&n); err != nil {
 76				return fmt.Errorf("scan migration number: %w", err)
 77			}
 78			executed[n] = true
 79		}
 80	case errors.Is(err, sql.ErrNoRows):
 81		slog.Info("db: migrations table not found; running all migrations")
 82	default:
 83		return fmt.Errorf("check migrations table: %w", err)
 84	}
 85
 86	for _, m := range migrations {
 87		match := pat.FindStringSubmatch(m)
 88		if len(match) != 2 {
 89			return fmt.Errorf("invalid migration filename: %s", m)
 90		}
 91		n, err := strconv.Atoi(match[1])
 92		if err != nil {
 93			return fmt.Errorf("parse migration number %s: %w", m, err)
 94		}
 95		if executed[n] {
 96			continue
 97		}
 98		if err := executeMigration(db, m); err != nil {
 99			return fmt.Errorf("execute %s: %w", m, err)
100		}
101		slog.Info("db: applied migration", "file", m, "number", n)
102	}
103	return nil
104}
105
106func executeMigration(db *sql.DB, filename string) error {
107	content, err := migrationFS.ReadFile("migrations/" + filename)
108	if err != nil {
109		return fmt.Errorf("read %s: %w", filename, err)
110	}
111	if _, err := db.Exec(string(content)); err != nil {
112		return fmt.Errorf("exec %s: %w", filename, err)
113	}
114	return nil
115}